summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11549.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/handlers/e2e_room_keys.py15
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py237
-rw-r--r--synapse/types.py6
-rw-r--r--tests/storage/test_e2e_room_keys.py4
6 files changed, 188 insertions, 79 deletions
diff --git a/changelog.d/11549.misc b/changelog.d/11549.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11549.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index 1867322044..e38ad635aa 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -27,7 +27,6 @@ exclude = (?x)
    |synapse/storage/databases/main/__init__.py
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
-   |synapse/storage/databases/main/e2e_room_keys.py
    |synapse/storage/databases/main/event_federation.py
    |synapse/storage/databases/main/event_push_actions.py
    |synapse/storage/databases/main/events_bg_updates.py
@@ -197,6 +196,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.directory]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.e2e_room_keys]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.end_to_end_keys]
 disallow_untyped_defs = True
 
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 31742236a9..12614b2c5d 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,9 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Dict, Optional
+
+from typing_extensions import Literal
 
 from synapse.api.errors import (
     Codes,
@@ -24,6 +26,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
 from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
@@ -58,7 +61,9 @@ class E2eRoomKeysHandler:
         version: str,
         room_id: Optional[str] = None,
         session_id: Optional[str] = None,
-    ) -> List[JsonDict]:
+    ) -> Dict[
+        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+    ]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -72,8 +77,8 @@ class E2eRoomKeysHandler:
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A list of dicts giving the session_data and message metadata for
-            these room keys.
+            A dict giving the session_data and message metadata for these room keys.
+            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
         """
 
         # we deliberately take the lock to get keys so that changing the version
@@ -273,7 +278,7 @@ class E2eRoomKeysHandler:
 
     @staticmethod
     def _should_replace_room_key(
-        current_room_key: Optional[JsonDict], room_key: JsonDict
+        current_room_key: Optional[RoomKey], room_key: RoomKey
     ) -> bool:
         """
         Determine whether to replace a given current_room_key (if any)
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
+
+from typing_extensions import Literal, TypedDict
 
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict, JsonSerializable
 from synapse.util import json_encoder
 
 
+class RoomKey(TypedDict):
+    """`KeyBackupData` in the Matrix spec.
+
+    https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
+    """
+
+    first_message_index: int
+    forwarded_count: int
+    is_verified: bool
+    session_data: JsonSerializable
+
+
 class EndToEndRoomKeyStore(SQLBaseStore):
+    """The store for end to end room key backups.
+
+    See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
+
+    As per the spec, backups are identified by an opaque version string. Internally,
+    version identifiers are assigned using incrementing integers. Non-numeric version
+    strings are treated as if they do not exist, since we would have never issued them.
+    """
+
     async def update_e2e_room_key(
-        self, user_id, version, room_id, session_id, room_key
-    ):
+        self,
+        user_id: str,
+        version: str,
+        room_id: str,
+        session_id: str,
+        room_key: RoomKey,
+    ) -> None:
         """Replaces the encrypted E2E room key for a given session in a given backup
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_id(str): the ID of the room whose keys we're setting
-            session_id(str): the session whose room_key we're setting
-            room_key(dict): the room_key being set
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_id: the ID of the room whose keys we're setting
+            session_id: the session whose room_key we're setting
+            room_key: the room_key being set
         Raises:
             StoreError
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            raise StoreError(404, "No backup with that version exists")
 
         await self.db_pool.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
-                "version": version,
+                "version": version_int,
                 "room_id": room_id,
                 "session_id": session_id,
             },
@@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="update_e2e_room_key",
         )
 
-    async def add_e2e_room_keys(self, user_id, version, room_keys):
+    async def add_e2e_room_keys(
+        self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
+    ) -> None:
         """Bulk add room keys to a given backup.
 
         Args:
-            user_id (str): the user whose backup we're adding to
-            version (str): the version ID of the backup for the set of keys we're adding to
-            room_keys (iterable[(str, str, dict)]): the keys to add, in the form
-                (roomID, sessionID, keyData)
+            user_id: the user whose backup we're adding to
+            version: the version ID of the backup for the set of keys we're adding to
+            room_keys: the keys to add, in the form (roomID, sessionID, keyData)
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            raise StoreError(404, "No backup with that version exists")
 
         values = []
         for (room_id, session_id, room_key) in room_keys:
             values.append(
                 {
                     "user_id": user_id,
-                    "version": version,
+                    "version": version_int,
                     "room_id": room_id,
                     "session_id": session_id,
                     "first_message_index": room_key["first_message_index"],
@@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_e2e_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> Dict[
+        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+    ]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup for the set of keys we're querying
-            room_id (str): Optional. the ID of the room whose keys we're querying, if any.
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup for the set of keys we're querying
+            room_id: Optional. the ID of the room whose keys we're querying, if any.
                 If not specified, we return the keys for all the rooms in the backup.
-            session_id (str): Optional. the session whose room_key we're querying, if any.
+            session_id: Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we return all the keys in this version of
                 the backup (or for the specified room)
 
         Returns:
-            A list of dicts giving the session_data and message metadata for
-            these room keys.
+            A dict giving the session_data and message metadata for these room keys.
+            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
         """
 
         try:
-            version = int(version)
+            version_int = int(version)
         except ValueError:
             return {"rooms": {}}
 
-        keyvalues = {"user_id": user_id, "version": version}
+        keyvalues = {"user_id": user_id, "version": version_int}
         if room_id:
             keyvalues["room_id"] = room_id
             if session_id:
@@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="get_e2e_room_keys",
         )
 
-        sessions = {"rooms": {}}
+        sessions: Dict[
+            Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+        ] = {"rooms": {}}
         for row in rows:
             room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
             room_entry["sessions"][row["session_id"]] = {
@@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return sessions
 
-    async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+    async def get_e2e_room_keys_multi(
+        self,
+        user_id: str,
+        version: str,
+        room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+    ) -> Dict[str, Dict[str, RoomKey]]:
         """Get multiple room keys at a time.  The difference between this function and
         get_e2e_room_keys is that this function can be used to retrieve
         multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         specific key.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
-            room_keys (dict[str, dict[str, iterable[str]]]): a map from
-                room ID -> {"session": [session ids]} indicating the session IDs
-                that we want to query
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
+            room_keys: a map from room ID -> {"sessions": [session ids]}
+                indicating the session IDs that we want to query
 
         Returns:
-           dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
+           A map of room IDs to session IDs to room key
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return {}
 
         return await self.db_pool.runInteraction(
             "get_e2e_room_keys_multi",
             self._get_e2e_room_keys_multi_txn,
             user_id,
-            version,
+            version_int,
             room_keys,
         )
 
     @staticmethod
-    def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+    def _get_e2e_room_keys_multi_txn(
+        txn: LoggingTransaction,
+        user_id: str,
+        version: int,
+        room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+    ) -> Dict[str, Dict[str, RoomKey]]:
         if not room_keys:
             return {}
 
@@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         txn.execute(sql, params)
 
-        ret = {}
+        ret: Dict[str, Dict[str, RoomKey]] = {}
 
         for row in txn:
             room_id = row[0]
@@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             user_id: the user whose backup we're querying
             version: the version ID of the backup we're querying about
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return 0
 
         return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
-            keyvalues={"user_id": user_id, "version": version},
+            keyvalues={"user_id": user_id, "version": version_int},
             retcol="COUNT(*)",
             desc="count_e2e_room_keys",
         )
 
     @trace
     async def delete_e2e_room_keys(
-        self, user_id, version, room_id=None, session_id=None
-    ):
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> None:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
 
         Args:
-            user_id(str): the user whose backup we're deleting from
-            version(str): the version ID of the backup for the set of keys we're deleting
-            room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
+            user_id: the user whose backup we're deleting from
+            version: the version ID of the backup for the set of keys we're deleting
+            room_id: Optional. the ID of the room whose keys we're deleting, if any.
                 If not specified, we delete the keys for all the rooms in the backup.
-            session_id(str): Optional. the session whose room_key we're querying, if any.
+            session_id: Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we delete all the keys in this version of
                 the backup (or for the specified room)
-
-        Returns:
-            The deletion transaction
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return
 
-        keyvalues = {"user_id": user_id, "version": int(version)}
+        keyvalues = {"user_id": user_id, "version": version_int}
         if room_id:
             keyvalues["room_id"] = room_id
             if session_id:
@@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @staticmethod
-    def _get_current_version(txn, user_id):
+    def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
         txn.execute(
             "SELECT MAX(version) FROM e2e_room_keys_versions "
             "WHERE user_id=? AND deleted=0",
             (user_id,),
         )
-        row = txn.fetchone()
-        if not row:
+        # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
+        # be `NULL` when there are no available versions.
+        row = cast(Tuple[Optional[int]], txn.fetchone())
+        if row[0] is None:
             raise StoreError(404, "No current backup version")
         return row[0]
 
-    async def get_e2e_room_keys_version_info(self, user_id, version=None):
+    async def get_e2e_room_keys_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get info metadata about a version of our room_keys backup.
 
         Args:
-            user_id(str): the user whose backup we're querying
-            version(str): Optional. the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: Optional. the version ID of the backup we're querying about
                 If missing, we return the information about the current version.
         Raises:
             StoreError: with code 404 if there are no e2e_room_keys_versions present
@@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 etag(int): tag of the keys in the backup
         """
 
-        def _get_e2e_room_keys_version_info_txn(txn):
+        def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
             else:
@@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 except ValueError:
                     # Our versions are all ints so if we can't convert it to an integer,
                     # it isn't there.
-                    raise StoreError(404, "No row found")
+                    raise StoreError(404, "No backup with that version exists")
 
             result = self.db_pool.simple_select_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
                 retcols=("version", "algorithm", "auth_data", "etag"),
+                allow_none=False,
             )
+            assert result is not None  # see comment on `simple_select_one_txn`
             result["auth_data"] = db_to_json(result["auth_data"])
             result["version"] = str(result["version"])
             if result["etag"] is None:
@@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
+    async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
         """Atomically creates a new version of this user's e2e_room_keys store
         with the given version info.
 
         Args:
-            user_id(str): the user whose backup we're creating a version
-            info(dict): the info about the backup version to be created
+            user_id: the user whose backup we're creating a version
+            info: the info about the backup version to be created
 
         Returns:
             The newly created version ID
         """
 
-        def _create_e2e_room_keys_version_txn(txn):
+        def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
             txn.execute(
                 "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
                 (user_id,),
             )
-            current_version = txn.fetchone()[0]
+            current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
             if current_version is None:
-                current_version = "0"
+                current_version = 0
 
-            new_version = str(int(current_version) + 1)
+            new_version = current_version + 1
 
             self.db_pool.simple_insert_txn(
                 txn,
@@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 },
             )
 
-            return new_version
+            return str(new_version)
 
         return await self.db_pool.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         self,
         user_id: str,
         version: str,
-        info: Optional[dict] = None,
+        info: Optional[JsonDict] = None,
         version_etag: Optional[int] = None,
     ) -> None:
         """Update a given backup version
@@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             version_etag: etag of the keys in the backup. If None, then the etag
                 is not updated.
         """
-        updatevalues = {}
+        updatevalues: Dict[str, object] = {}
 
         if info is not None and "auth_data" in info:
             updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             updatevalues["etag"] = version_etag
 
         if updatevalues:
-            await self.db_pool.simple_update(
+            try:
+                version_int = int(version)
+            except ValueError:
+                # Our versions are all ints so if we can't convert it to an integer,
+                # it doesn't exist.
+                raise StoreError(404, "No backup with that version exists")
+
+            await self.db_pool.simple_update_one(
                 table="e2e_room_keys_versions",
-                keyvalues={"user_id": user_id, "version": version},
+                keyvalues={"user_id": user_id, "version": version_int},
                 updatevalues=updatevalues,
                 desc="update_e2e_room_keys_version",
             )
@@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 or if the version requested doesn't exist.
         """
 
-        def _delete_e2e_room_keys_version_txn(txn):
+        def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
-                if this_version is None:
-                    raise StoreError(404, "No current backup version")
             else:
-                this_version = version
+                try:
+                    this_version = int(version)
+                except ValueError:
+                    # Our versions are all ints so if we can't convert it to an integer,
+                    # it isn't there.
+                    raise StoreError(404, "No backup with that version exists")
 
             self.db_pool.simple_delete_txn(
                 txn,
diff --git a/synapse/types.py b/synapse/types.py
index fb72f19343..b06979e8e8 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -59,9 +59,11 @@ StateKey = Tuple[str, str]
 StateMap = Mapping[StateKey, T]
 MutableStateMap = MutableMapping[StateKey, T]
 
-# the type of a JSON-serialisable dict. This could be made stronger, but it will
-# do for now.
+# JSON types. These could be made stronger, but will do for now.
+# A JSON-serialisable dict.
 JsonDict = Dict[str, Any]
+# A JSON-serialisable object.
+JsonSerializable = object
 
 
 # Note that this seems to require inheriting *directly* from Interface in order
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 9b6b425425..7556171d8a 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
+
 from tests import unittest
 
 # sample room_key data for use in the tests
-room_key = {
+room_key: RoomKey = {
     "first_message_index": 1,
     "forwarded_count": 1,
     "is_verified": False,