diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b4a1b041b1..599b418383 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill,
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
+ ) -> None:
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_to_device_stream_token(self):
+ def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()
async def get_messages_for_user_devices(
@@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not user_ids_to_query:
return {}, to_stream_id
- def get_device_messages_txn(txn: LoggingTransaction):
+ def get_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.
@@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
- def delete_messages_for_device_txn(txn):
+ def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def get_new_device_msgs_for_remote(
- self, destination, last_stream_id, current_stream_id, limit
- ) -> Tuple[List[dict], int]:
+ self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
+ ) -> Tuple[List[JsonDict], int]:
"""
Args:
- destination(str): The name of the remote server.
- last_stream_id(int|long): The last position of the device message stream
+ destination: The name of the remote server.
+ last_stream_id: The last position of the device message stream
that the server sent up to.
- current_stream_id(int|long): The current position of the device
- message stream.
+ current_stream_id: The current position of the device message stream.
Returns:
A list of messages for the device and where in the stream the messages got to.
"""
@@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return [], last_stream_id
@trace
- def get_new_messages_for_remote_destination_txn(txn):
+ def get_new_messages_for_remote_destination_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
@@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
up_to_stream_id: Where to delete messages up to.
"""
- def delete_messages_for_remote_destination_txn(txn):
+ def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
@@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_device_messages_txn(txn):
+ def get_all_new_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
@@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def add_messages_to_device_inbox(
self,
- local_messages_by_user_then_device: dict,
- remote_messages_by_destination: dict,
+ local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+ remote_messages_by_destination: Dict[str, JsonDict],
) -> int:
"""Used to send messages from this server.
@@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
assert self._can_write_to_device
- def add_messages_txn(txn, now_ms, stream_id):
+ def add_messages_txn(
+ txn: LoggingTransaction, now_ms: int, stream_id: int
+ ) -> None:
# Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
@@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return self._device_inbox_id_gen.get_current_token()
async def add_messages_from_remote_to_device_inbox(
- self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+ self,
+ origin: str,
+ message_id: str,
+ local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> int:
assert self._can_write_to_device
- def add_messages_txn(txn, now_ms, stream_id):
+ def add_messages_txn(
+ txn: LoggingTransaction, now_ms: int, stream_id: int
+ ) -> None:
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return stream_id
def _add_messages_to_local_device_inbox_txn(
- self, txn, stream_id, messages_by_user_then_device
- ):
+ self,
+ txn: LoggingTransaction,
+ stream_id: int,
+ messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+ ) -> None:
assert self._can_write_to_device
local_by_user_then_device = {}
@@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_dead_devices_from_device_inbox,
)
- async def _background_drop_index_device_inbox(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_drop_index_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
|