summary refs log tree commit diff
path: root/synapse/storage/databases/main/deviceinbox.py
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2022-04-27 14:05:00 +0200
committerGitHub <noreply@github.com>2022-04-27 13:05:00 +0100
commitb76f1a4d5f918def1f643910939b80e9e035e07f (patch)
treeb0492ace0e54340b0b40d990a298c2c249274427 /synapse/storage/databases/main/deviceinbox.py
parentBound ephemeral events by key (#12544) (diff)
downloadsynapse-b76f1a4d5f918def1f643910939b80e9e035e07f.tar.xz
Add some type hints to datastore (#12485)
Diffstat (limited to 'synapse/storage/databases/main/deviceinbox.py')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py79
1 files changed, 56 insertions, 23 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b4a1b041b1..599b418383 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             prefilled_cache=device_outbox_prefill,
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
+    ) -> None:
         if stream_name == ToDeviceStream.NAME:
             # If replication is happening than postgres must be being used.
             assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     )
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_to_device_stream_token(self):
+    def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()
 
     async def get_messages_for_user_devices(
@@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         if not user_ids_to_query:
             return {}, to_stream_id
 
-        def get_device_messages_txn(txn: LoggingTransaction):
+        def get_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
             # Build a query to select messages from any of the given devices that
             # are between the given stream id bounds.
 
@@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
-        def delete_messages_for_device_txn(txn):
+        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "DELETE FROM device_inbox"
                 " WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     @trace
     async def get_new_device_msgs_for_remote(
-        self, destination, last_stream_id, current_stream_id, limit
-    ) -> Tuple[List[dict], int]:
+        self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
+    ) -> Tuple[List[JsonDict], int]:
         """
         Args:
-            destination(str): The name of the remote server.
-            last_stream_id(int|long): The last position of the device message stream
+            destination: The name of the remote server.
+            last_stream_id: The last position of the device message stream
                 that the server sent up to.
-            current_stream_id(int|long): The current position of the device
-                message stream.
+            current_stream_id: The current position of the device message stream.
         Returns:
             A list of messages for the device and where in the stream the messages got to.
         """
@@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             return [], last_stream_id
 
         @trace
-        def get_new_messages_for_remote_destination_txn(txn):
+        def get_new_messages_for_remote_destination_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             sql = (
                 "SELECT stream_id, messages_json FROM device_federation_outbox"
                 " WHERE destination = ?"
@@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             up_to_stream_id: Where to delete messages up to.
         """
 
-        def delete_messages_for_remote_destination_txn(txn):
+        def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
             sql = (
                 "DELETE FROM device_federation_outbox"
                 " WHERE destination = ?"
@@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_device_messages_txn(txn):
+        def get_all_new_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We limit like this as we might have multiple rows per stream_id, and
             # we want to make sure we always get all entries for any stream_id
             # we return.
@@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     @trace
     async def add_messages_to_device_inbox(
         self,
-        local_messages_by_user_then_device: dict,
-        remote_messages_by_destination: dict,
+        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+        remote_messages_by_destination: Dict[str, JsonDict],
     ) -> int:
         """Used to send messages from this server.
 
@@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         assert self._can_write_to_device
 
-        def add_messages_txn(txn, now_ms, stream_id):
+        def add_messages_txn(
+            txn: LoggingTransaction, now_ms: int, stream_id: int
+        ) -> None:
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
                 txn, stream_id, local_messages_by_user_then_device
@@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return self._device_inbox_id_gen.get_current_token()
 
     async def add_messages_from_remote_to_device_inbox(
-        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+        self,
+        origin: str,
+        message_id: str,
+        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
     ) -> int:
         assert self._can_write_to_device
 
-        def add_messages_txn(txn, now_ms, stream_id):
+        def add_messages_txn(
+            txn: LoggingTransaction, now_ms: int, stream_id: int
+        ) -> None:
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
             # acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return stream_id
 
     def _add_messages_to_local_device_inbox_txn(
-        self, txn, stream_id, messages_by_user_then_device
-    ):
+        self,
+        txn: LoggingTransaction,
+        stream_id: int,
+        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+    ) -> None:
         assert self._can_write_to_device
 
         local_by_user_then_device = {}
@@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self._remove_dead_devices_from_device_inbox,
         )
 
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_drop_index_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()