summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-12-13 16:28:26 +0000
committerGitHub <noreply@github.com>2021-12-13 16:28:26 +0000
commit1abfb15f07d4f8119afcf908f9e1903e7feef371 (patch)
treeeef0ee44bb012e486fc691e24c1e6fe552771fcf /synapse
parentAdd type hints to `synapse/storage/databases/main/account_data.py` (#11546) (diff)
downloadsynapse-1abfb15f07d4f8119afcf908f9e1903e7feef371.tar.xz
Add type hints to `synapse/storage/databases/main/end_to_end_keys.py` (#11551)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/__init__.py3
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py211
2 files changed, 146 insertions, 68 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..065145c0d2 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -143,9 +143,6 @@ class DataStore(
                 ("device_lists_outbound_pokes", "stream_id"),
             ],
         )
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn, "e2e_cross_signing_keys", "stream_id"
-        )
 
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 import attr
 from canonicaljson import encode_canonical_json
 
-from twisted.enterprise.adbapi import Connection
-
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
 
 
 class EndToEndKeyBackgroundStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
         )
 
 
-class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._allow_device_name_lookup_over_federation = (
@@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
-        rv = {}
+        rv: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
@@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             # add each cross-signing signature to the correct device in the result dict.
             for (user_id, key_id, device_id, signature) in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
+                # We've only looked up cross-signatures for non-deleted devices with key
+                # data.
+                assert target_device_result is not None
+                assert target_device_result.keys is not None
                 target_device_signatures = target_device_result.keys.setdefault(
                     "signatures", {}
                 )
@@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_device_keys_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+        self,
+        txn: LoggingTransaction,
+        query_list: Collection[Tuple[str, str]],
+        include_all_devices: bool = False,
+        include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
         """Get information on devices from the database
 
@@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_cross_signing_signatures_for_devices_txn(
-        self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
     ) -> List[Tuple[str, str, str, str]]:
         """Get cross-signing signatures for a given list of devices
 
@@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
         txn.execute(signature_sql, signature_query_params)
-        return txn.fetchall()
+        return cast(
+            List[
+                Tuple[
+                    str,
+                    str,
+                    str,
+                    str,
+                ]
+            ],
+            txn.fetchall(),
+        )
 
     async def get_e2e_one_time_keys(
         self, user_id: str, device_id: str, key_ids: List[str]
@@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
-        def _add_e2e_one_time_keys(txn):
+        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("new_keys", new_keys)
@@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             A mapping from algorithm to number of keys for that algorithm.
         """
 
-        def _count_e2e_one_time_keys(txn):
+        def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
             sql = (
                 "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ?"
@@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
     def _set_e2e_fallback_keys_txn(
-        self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        fallback_keys: JsonDict,
     ) -> None:
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """Returns a user's cross-signing key.
 
         Args:
@@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id):
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             their user ID will map to None.
 
         """
-        return await self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
         )
 
+        # The `Optional` comes from the `@cachedList` decorator.
+        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
-        txn: Connection,
+        txn: LoggingTransaction,
         user_ids: Iterable[str],
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Dict[str, JsonDict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_ids (list[str]): the users whose keys are being requested
+            txn: db connection
+            user_ids: the users whose keys are being requested
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  If a user's cross-signing keys were not found, their user
-                ID will not be in the dict.
+            Mapping from user ID to key type to key data.
+            If a user's cross-signing keys were not found, their user ID will not be in
+            the dict.
 
         """
-        result = {}
+        result: Dict[str, Dict[str, JsonDict]] = {}
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
@@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
                 user_id = row["user_id"]
                 key_type = row["keytype"]
                 key = db_to_json(row["keydata"])
-                user_info = result.setdefault(user_id, {})
-                user_info[key_type] = key
+                user_keys = result.setdefault(user_id, {})
+                user_keys[key_type] = key
 
         return result
 
     def _get_e2e_cross_signing_signatures_txn(
         self,
-        txn: Connection,
-        keys: Dict[str, Dict[str, dict]],
+        txn: LoggingTransaction,
+        keys: Dict[str, Optional[Dict[str, JsonDict]]],
         from_user_id: str,
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing signatures made by a user on a set of keys.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
-                key data.  This dict will be modified to add signatures.
-            from_user_id (str): fetch the signatures made by this user
+            txn: db connection
+            keys: a map of user ID to key type to key data.
+                This dict will be modified to add signatures.
+            from_user_id: fetch the signatures made by this user
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  The return value will be the same as the keys argument,
-                with the modifications included.
+            Mapping from user ID to key type to key data.
+            The return value will be the same as the keys argument, with the
+            modifications included.
         """
 
         # find out what cross-signing keys (a.k.a. devices) we need to get
         # signatures for.  This is a map of (user_id, device_id) to key type
         # (device_id is the key's public part).
-        devices = {}
+        devices: Dict[Tuple[str, str], str] = {}
 
-        for user_id, user_info in keys.items():
-            if user_info is None:
+        for user_id, user_keys in keys.items():
+            if user_keys is None:
                 continue
-            for key_type, key in user_info.items():
+            for key_type, key in user_keys.items():
                 device_id = None
                 for k in key["keys"].values():
                     device_id = k
+                # `key` ought to be a `CrossSigningKey`, whose .keys property is a
+                # dictionary with a single entry:
+                #     "algorithm:base64_public_key": "base64_public_key"
+                # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
+                assert isinstance(device_id, str)
                 devices[(user_id, device_id)] = key_type
 
         for batch in batch_iter(devices.keys(), size=100):
@@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
             # and add the signatures to the appropriate keys
             for row in rows:
-                key_id = row["key_id"]
-                target_user_id = row["target_user_id"]
-                target_device_id = row["target_device_id"]
+                key_id: str = row["key_id"]
+                target_user_id: str = row["target_user_id"]
+                target_device_id: str = row["target_device_id"]
                 key_type = devices[(target_user_id, target_device_id)]
                 # We need to copy everything, because the result may have come
                 # from the cache.  dict.copy only does a shallow copy, so we
                 # need to recursively copy the dicts that will be modified.
-                user_info = keys[target_user_id] = keys[target_user_id].copy()
-                target_user_key = user_info[key_type] = user_info[key_type].copy()
+                user_keys = keys[target_user_id]
+                # `user_keys` cannot be `None` because we only fetched signatures for
+                # users with keys
+                assert user_keys is not None
+                user_keys = keys[target_user_id] = user_keys.copy()
+
+                target_user_key = user_keys[key_type] = user_keys[key_type].copy()
                 if "signatures" in target_user_key:
                     signatures = target_user_key["signatures"] = target_user_key[
                         "signatures"
@@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, dict]]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def _get_all_user_signature_changes_for_remotes_txn(txn):
+        def _get_all_user_signature_changes_for_remotes_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT stream_id, from_user_id AS user_id
                 FROM user_signature_stream
@@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         @trace
         def _claim_e2e_one_time_key_simple(
-            txn, user_id: str, device_id: str, algorithm: str
+            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
         ) -> Optional[Tuple[str, str]]:
             """Claim OTK for device for DBs that don't support RETURNING.
 
@@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         @trace
         def _claim_e2e_one_time_key_returning(
-            txn, user_id: str, device_id: str, algorithm: str
+            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
         ) -> Optional[Tuple[str, str]]:
             """Claim OTK for device for DBs that support RETURNING.
 
@@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             key_id, key_json = otk_row
             return f"{algorithm}:{key_id}", key_json
 
-        results = {}
+        results: Dict[str, Dict[str, Dict[str, str]]] = {}
         for user_id, device_id, algorithm in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
@@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._cross_signing_id_gen = StreamIdGenerator(
+            db_conn, "e2e_cross_signing_keys", "stream_id"
+        )
+
     async def set_e2e_device_keys(
         self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
     ) -> bool:
@@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         or the keys were already in the database.
         """
 
-        def _set_e2e_device_keys_txn(txn):
+        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
@@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         )
 
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
-        def delete_e2e_keys_by_device_txn(txn):
+        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
             log_kv(
                 {
                     "message": "Deleting keys for device",
@@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
+    def _set_e2e_cross_signing_key_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        key_type: str,
+        key: JsonDict,
+        stream_id: int,
+    ) -> None:
         """Set a user's cross-signing key.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user to set the signing key for
-            key_type (str): the type of key that is being set: either 'master'
+            txn: db connection
+            user_id: the user to set the signing key for
+            key_type: the type of key that is being set: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
-            key (dict): the key data
-            stream_id (int)
+            key: the key data
+            stream_id
         """
         # the 'key' dict will look something like:
         # {
@@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(
+        self, user_id: str, key_type: str, key: JsonDict
+    ) -> None:
         """Set a user's cross-signing key.
 
         Args:
-            user_id (str): the user to set the user-signing key for
-            key_type (str): the type of cross-signing key to set
-            key (dict): the key data
+            user_id: the user to set the user-signing key for
+            key_type: the type of cross-signing key to set
+            key: the key data
         """
 
         async with self._cross_signing_id_gen.get_next() as stream_id: