summary refs log tree commit diff
path: root/synapse/storage/databases/main/deviceinbox.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/deviceinbox.py')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py94
1 files changed, 49 insertions, 45 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 76ec954f44..1f6e995c4f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -16,8 +16,6 @@
 import logging
 from typing import List, Tuple
 
-from twisted.internet import defer
-
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
@@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
-    def get_new_messages_for_device(
-        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
-    ):
+    async def get_new_messages_for_device(
+        self,
+        user_id: str,
+        device_id: str,
+        last_stream_id: int,
+        current_stream_id: int,
+        limit: int = 100,
+    ) -> Tuple[List[dict], int]:
         """
         Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            current_stream_id(int): The current position of the to device
+            user_id: The recipient user_id.
+            device_id: The recipient device_id.
+            last_stream_id: The last stream ID checked.
+            current_stream_id: The current position of the to device
                 message stream.
+            limit: The maximum number of messages to retrieve.
+
         Returns:
-            Deferred ([dict], int): List of messages for the device and where
-                in the stream the messages got to.
+            A list of messages for the device and where in the stream the messages got to.
         """
         has_changed = self._device_inbox_stream_cache.has_entity_changed(
             user_id, last_stream_id
         )
         if not has_changed:
-            return defer.succeed(([], current_stream_id))
+            return ([], current_stream_id)
 
         def get_new_messages_for_device_txn(txn):
             sql = (
@@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_new_messages_for_device", get_new_messages_for_device_txn
         )
 
     @trace
-    @defer.inlineCallbacks
-    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+    async def delete_messages_for_device(
+        self, user_id: str, device_id: str, up_to_stream_id: int
+    ) -> int:
         """
         Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            up_to_stream_id(int): Where to delete messages up to.
+            user_id: The recipient user_id.
+            device_id: The recipient device_id.
+            up_to_stream_id: Where to delete messages up to.
+
         Returns:
-            A deferred that resolves to the number of messages deleted.
+            The number of messages deleted.
         """
         # If we have cached the last stream id we've deleted up to, we can
         # check if there is likely to be anything that needs deleting
@@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        count = yield self.db_pool.runInteraction(
+        count = await self.db_pool.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
@@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return count
 
     @trace
-    def get_new_device_msgs_for_remote(
+    async def get_new_device_msgs_for_remote(
         self, destination, last_stream_id, current_stream_id, limit
-    ):
+    ) -> Tuple[List[dict], int]:
         """
         Args:
             destination(str): The name of the remote server.
@@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             current_stream_id(int|long): The current position of the device
                 message stream.
         Returns:
-            Deferred ([dict], int|long): List of messages for the device and where
-                in the stream the messages got to.
+            A list of messages for the device and where in the stream the messages got to.
         """
 
         set_tag("destination", destination)
@@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         )
         if not has_changed or last_stream_id == current_stream_id:
             log_kv({"message": "No new messages in stream"})
-            return defer.succeed(([], current_stream_id))
+            return ([], current_stream_id)
 
         if limit <= 0:
             # This can happen if we run out of room for EDUs in the transaction.
-            return defer.succeed(([], last_stream_id))
+            return ([], last_stream_id)
 
         @trace
         def get_new_messages_for_remote_destination_txn(txn):
@@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )
@@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
-    @defer.inlineCallbacks
-    def _background_drop_index_device_inbox(self, progress, batch_size):
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()
 
-        yield self.db_pool.runWithConnection(reindex_txn)
+        await self.db_pool.runWithConnection(reindex_txn)
 
-        yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
 
         return 1
 
@@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         )
 
     @trace
-    @defer.inlineCallbacks
-    def add_messages_to_device_inbox(
-        self, local_messages_by_user_then_device, remote_messages_by_destination
-    ):
+    async def add_messages_to_device_inbox(
+        self,
+        local_messages_by_user_then_device: dict,
+        remote_messages_by_destination: dict,
+    ) -> int:
         """Used to send messages from this server.
 
         Args:
-            sender_user_id(str): The ID of the user sending these messages.
-            local_messages_by_user_and_device(dict):
+            local_messages_by_user_and_device:
                 Dictionary of user_id to device_id to message.
-            remote_messages_by_destination(dict):
+            remote_messages_by_destination:
                 Dictionary of destination server_name to the EDU JSON to send.
+
         Returns:
-            A deferred stream_id that resolves when the messages have been
-            inserted.
+            The new stream_id.
         """
 
         def add_messages_txn(txn, now_ms, stream_id):
@@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
             for user_id in local_messages_by_user_then_device.keys():
@@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         return self._device_inbox_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def add_messages_from_remote_to_device_inbox(
-        self, origin, message_id, local_messages_by_user_then_device
-    ):
+    async def add_messages_from_remote_to_device_inbox(
+        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+    ) -> int:
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
@@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
                 now_ms,