diff options
-rw-r--r-- | changelog.d/13025.misc | 1 | ||||
-rw-r--r-- | mypy.ini | 1 | ||||
-rw-r--r-- | synapse/replication/slave/storage/devices.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/__init__.py | 1 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 51 |
5 files changed, 36 insertions, 21 deletions
diff --git a/changelog.d/13025.misc b/changelog.d/13025.misc new file mode 100644 index 0000000000..7cb0d174b7 --- /dev/null +++ b/changelog.d/13025.misc @@ -0,0 +1 @@ +Add type annotations to `synapse.storage.databases.main.devices`. diff --git a/mypy.ini b/mypy.ini index 7973f2ac01..c5130feaec 100644 --- a/mypy.ini +++ b/mypy.ini @@ -27,7 +27,6 @@ exclude = (?x) ^( |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py - |synapse/storage/databases/main/devices.py |synapse/storage/schema/ |tests/api/test_auth.py diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 30717c2bd0..a48cc02069 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -19,13 +19,12 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.devices import DeviceWorkerStore -from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore if TYPE_CHECKING: from synapse.server import HomeServer -class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): +class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore): def __init__( self, database: DatabasePool, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index cb3d1242bb..57aaf778ec 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -195,6 +195,7 @@ class DataStore( self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_device_stream_token(self) -> int: + # TODO: shouldn't this be moved to `DeviceWorkerStore`? return self._device_list_id_gen.get_current_token() async def get_users(self) -> List[JsonDict]: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2414a7dc38..03d1334e03 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,8 @@ from typing import ( cast, ) +from typing_extensions import Literal + from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -44,6 +46,8 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.types import Cursor from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" -class DeviceWorkerStore(SQLBaseStore): +class DeviceWorkerStore(EndToEndKeyWorkerStore): def __init__( self, database: DatabasePool, @@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - device_list_max = self._device_list_id_gen.get_current_token() + # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a + # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). + device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined] device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( db_conn, "device_lists_stream", @@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore): # following this stream later. last_processed_stream_id = from_stream_id - query_map = {} - cross_signing_keys_by_user = {} + # A map of (user ID, device ID) to (stream ID, context). + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {} + cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {} for user_id, device_id, update_stream_id, update_context in updates: # Calculate the remaining length budget. # Note that, for now, each entry in `cross_signing_keys_by_user` @@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore): txn=txn, table="device_lists_outbound_last_success", key_names=("destination", "user_id"), - key_values=((destination, user_id) for user_id, _ in rows), + key_values=[(destination, user_id) for user_id, _ in rows], value_names=("stream_id",), value_values=((stream_id,) for _, stream_id in rows), ) @@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore): The new stream ID. """ - async with self._device_list_id_gen.get_next() as stream_id: + # TODO: this looks like it's _writing_. Should this be on DeviceStore rather + # than DeviceWorkerStore? + async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore): } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache - results = {} + results: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_id in query_list: if user_id not in user_ids_in_cache: continue @@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore): def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[Set[str]]: + ) -> Optional[List[str]]: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ @@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore): async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Iterable[str]] = None, + user_ids: Optional[Collection[str]] = None, to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore): """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. + user_ids_to_check: Optional[Collection[str]] if user_ids is None: # Get set of all users that have had device list changes since 'from_key' user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( @@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore): return set() def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - changes = set() + changes: Set[str] = set() stream_id_where_clause = "stream_id > ?" sql_args = [from_key] @@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore): """ # Query device changes with a batch of users at a time + # Assertion for mypy's benefit; see also + # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions + assert user_ids_to_check is not None for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk @@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def _get_all_device_list_changes_for_remotes(txn): + def _get_all_device_list_changes_for_remotes( + txn: Cursor, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # This query Does The Right Thing where it'll correctly apply the # bounds to the inner queries. sql = """ @@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_device_list_last_stream_id_for_remotes", ) - results = {user_id: None for user_id in user_ids} + results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} results.update({row["user_id"]: row["stream_id"] for row in rows}) return results @@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = LruCache( - cache_name="device_id_exists", max_size=10000 - ) + self.device_id_exists_cache: LruCache[ + Tuple[str, str], Literal[True] + ] = LruCache(cache_name="device_id_exists", max_size=10000) async def store_device( self, @@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context, ) - async with self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined] len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_ids: Iterable[str], hosts: Collection[str], stream_ids: List[int], - context: Dict[str, str], + context: Optional[Dict[str, str]], ) -> None: for host in hosts: txn.call_after( @@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [], ) - async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined] return await self.db_pool.runInteraction( "add_device_list_outbound_pokes", add_device_list_outbound_pokes_txn, |