summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-12-14 17:46:47 +0000
committerGitHub <noreply@github.com>2021-12-14 17:46:47 +0000
commitecfcd9bbbed62f8ca5d42bd8ba02b696e9079573 (patch)
tree953d3c342445153c29e8c1a2413abac8c82658b0 /synapse/storage/databases
parentAdd missing type hints to `synapse.logging.context` (#11556) (diff)
downloadsynapse-ecfcd9bbbed62f8ca5d42bd8ba02b696e9079573.tar.xz
Add type hints to `synapse/storage/databases/main/e2e_room_keys.py` (#11549)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py237
1 files changed, 167 insertions, 70 deletions
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
+
+from typing_extensions import Literal, TypedDict
 
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict, JsonSerializable
 from synapse.util import json_encoder
 
 
+class RoomKey(TypedDict):
+    """`KeyBackupData` in the Matrix spec.
+
+    https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
+    """
+
+    first_message_index: int
+    forwarded_count: int
+    is_verified: bool
+    session_data: JsonSerializable
+
+
 class EndToEndRoomKeyStore(SQLBaseStore):
+    """The store for end to end room key backups.
+
+    See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
+
+    As per the spec, backups are identified by an opaque version string. Internally,
+    version identifiers are assigned using incrementing integers. Non-numeric version
+    strings are treated as if they do not exist, since we would have never issued them.
+    """
+
     async def update_e2e_room_key(
-        self, user_id, version, room_id, session_id, room_key
-    ):
+        self,
+        user_id: str,
+        version: str,
+        room_id: str,
+        session_id: str,
+        room_key: RoomKey,
+    ) -> None:
         """Replaces the encrypted E2E room key for a given session in a given backup
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_id(str): the ID of the room whose keys we're setting
-            session_id(str): the session whose room_key we're setting
-            room_key(dict): the room_key being set
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_id: the ID of the room whose keys we're setting
+            session_id: the session whose room_key we're setting
+            room_key: the room_key being set
         Raises:
             StoreError
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            raise StoreError(404, "No backup with that version exists")
 
         await self.db_pool.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
-                "version": version,
+                "version": version_int,
                 "room_id": room_id,
                 "session_id": session_id,
             },
@@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="update_e2e_room_key",
         )
 
-    async def add_e2e_room_keys(self, user_id, version, room_keys):
+    async def add_e2e_room_keys(
+        self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
+    ) -> None:
         """Bulk add room keys to a given backup.
 
         Args:
-            user_id (str): the user whose backup we're adding to
-            version (str): the version ID of the backup for the set of keys we're adding to
-            room_keys (iterable[(str, str, dict)]): the keys to add, in the form
-                (roomID, sessionID, keyData)
+            user_id: the user whose backup we're adding to
+            version: the version ID of the backup for the set of keys we're adding to
+            room_keys: the keys to add, in the form (roomID, sessionID, keyData)
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            raise StoreError(404, "No backup with that version exists")
 
         values = []
         for (room_id, session_id, room_key) in room_keys:
             values.append(
                 {
                     "user_id": user_id,
-                    "version": version,
+                    "version": version_int,
                     "room_id": room_id,
                     "session_id": session_id,
                     "first_message_index": room_key["first_message_index"],
@@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_e2e_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> Dict[
+        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+    ]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup for the set of keys we're querying
-            room_id (str): Optional. the ID of the room whose keys we're querying, if any.
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup for the set of keys we're querying
+            room_id: Optional. the ID of the room whose keys we're querying, if any.
                 If not specified, we return the keys for all the rooms in the backup.
-            session_id (str): Optional. the session whose room_key we're querying, if any.
+            session_id: Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we return all the keys in this version of
                 the backup (or for the specified room)
 
         Returns:
-            A list of dicts giving the session_data and message metadata for
-            these room keys.
+            A dict giving the session_data and message metadata for these room keys.
+            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
         """
 
         try:
-            version = int(version)
+            version_int = int(version)
         except ValueError:
             return {"rooms": {}}
 
-        keyvalues = {"user_id": user_id, "version": version}
+        keyvalues = {"user_id": user_id, "version": version_int}
         if room_id:
             keyvalues["room_id"] = room_id
             if session_id:
@@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="get_e2e_room_keys",
         )
 
-        sessions = {"rooms": {}}
+        sessions: Dict[
+            Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+        ] = {"rooms": {}}
         for row in rows:
             room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
             room_entry["sessions"][row["session_id"]] = {
@@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return sessions
 
-    async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+    async def get_e2e_room_keys_multi(
+        self,
+        user_id: str,
+        version: str,
+        room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+    ) -> Dict[str, Dict[str, RoomKey]]:
         """Get multiple room keys at a time.  The difference between this function and
         get_e2e_room_keys is that this function can be used to retrieve
         multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         specific key.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
-            room_keys (dict[str, dict[str, iterable[str]]]): a map from
-                room ID -> {"session": [session ids]} indicating the session IDs
-                that we want to query
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
+            room_keys: a map from room ID -> {"sessions": [session ids]}
+                indicating the session IDs that we want to query
 
         Returns:
-           dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
+           A map of room IDs to session IDs to room key
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return {}
 
         return await self.db_pool.runInteraction(
             "get_e2e_room_keys_multi",
             self._get_e2e_room_keys_multi_txn,
             user_id,
-            version,
+            version_int,
             room_keys,
         )
 
     @staticmethod
-    def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+    def _get_e2e_room_keys_multi_txn(
+        txn: LoggingTransaction,
+        user_id: str,
+        version: int,
+        room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+    ) -> Dict[str, Dict[str, RoomKey]]:
         if not room_keys:
             return {}
 
@@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         txn.execute(sql, params)
 
-        ret = {}
+        ret: Dict[str, Dict[str, RoomKey]] = {}
 
         for row in txn:
             room_id = row[0]
@@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             user_id: the user whose backup we're querying
             version: the version ID of the backup we're querying about
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return 0
 
         return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
-            keyvalues={"user_id": user_id, "version": version},
+            keyvalues={"user_id": user_id, "version": version_int},
             retcol="COUNT(*)",
             desc="count_e2e_room_keys",
         )
 
     @trace
     async def delete_e2e_room_keys(
-        self, user_id, version, room_id=None, session_id=None
-    ):
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> None:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
 
         Args:
-            user_id(str): the user whose backup we're deleting from
-            version(str): the version ID of the backup for the set of keys we're deleting
-            room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
+            user_id: the user whose backup we're deleting from
+            version: the version ID of the backup for the set of keys we're deleting
+            room_id: Optional. the ID of the room whose keys we're deleting, if any.
                 If not specified, we delete the keys for all the rooms in the backup.
-            session_id(str): Optional. the session whose room_key we're querying, if any.
+            session_id: Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we delete all the keys in this version of
                 the backup (or for the specified room)
-
-        Returns:
-            The deletion transaction
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return
 
-        keyvalues = {"user_id": user_id, "version": int(version)}
+        keyvalues = {"user_id": user_id, "version": version_int}
         if room_id:
             keyvalues["room_id"] = room_id
             if session_id:
@@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @staticmethod
-    def _get_current_version(txn, user_id):
+    def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
         txn.execute(
             "SELECT MAX(version) FROM e2e_room_keys_versions "
             "WHERE user_id=? AND deleted=0",
             (user_id,),
         )
-        row = txn.fetchone()
-        if not row:
+        # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
+        # be `NULL` when there are no available versions.
+        row = cast(Tuple[Optional[int]], txn.fetchone())
+        if row[0] is None:
             raise StoreError(404, "No current backup version")
         return row[0]
 
-    async def get_e2e_room_keys_version_info(self, user_id, version=None):
+    async def get_e2e_room_keys_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get info metadata about a version of our room_keys backup.
 
         Args:
-            user_id(str): the user whose backup we're querying
-            version(str): Optional. the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: Optional. the version ID of the backup we're querying about
                 If missing, we return the information about the current version.
         Raises:
             StoreError: with code 404 if there are no e2e_room_keys_versions present
@@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 etag(int): tag of the keys in the backup
         """
 
-        def _get_e2e_room_keys_version_info_txn(txn):
+        def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
             else:
@@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 except ValueError:
                     # Our versions are all ints so if we can't convert it to an integer,
                     # it isn't there.
-                    raise StoreError(404, "No row found")
+                    raise StoreError(404, "No backup with that version exists")
 
             result = self.db_pool.simple_select_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
                 retcols=("version", "algorithm", "auth_data", "etag"),
+                allow_none=False,
             )
+            assert result is not None  # see comment on `simple_select_one_txn`
             result["auth_data"] = db_to_json(result["auth_data"])
             result["version"] = str(result["version"])
             if result["etag"] is None:
@@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
+    async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
         """Atomically creates a new version of this user's e2e_room_keys store
         with the given version info.
 
         Args:
-            user_id(str): the user whose backup we're creating a version
-            info(dict): the info about the backup version to be created
+            user_id: the user whose backup we're creating a version
+            info: the info about the backup version to be created
 
         Returns:
             The newly created version ID
         """
 
-        def _create_e2e_room_keys_version_txn(txn):
+        def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
             txn.execute(
                 "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
                 (user_id,),
             )
-            current_version = txn.fetchone()[0]
+            current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
             if current_version is None:
-                current_version = "0"
+                current_version = 0
 
-            new_version = str(int(current_version) + 1)
+            new_version = current_version + 1
 
             self.db_pool.simple_insert_txn(
                 txn,
@@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 },
             )
 
-            return new_version
+            return str(new_version)
 
         return await self.db_pool.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         self,
         user_id: str,
         version: str,
-        info: Optional[dict] = None,
+        info: Optional[JsonDict] = None,
         version_etag: Optional[int] = None,
     ) -> None:
         """Update a given backup version
@@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             version_etag: etag of the keys in the backup. If None, then the etag
                 is not updated.
         """
-        updatevalues = {}
+        updatevalues: Dict[str, object] = {}
 
         if info is not None and "auth_data" in info:
             updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             updatevalues["etag"] = version_etag
 
         if updatevalues:
-            await self.db_pool.simple_update(
+            try:
+                version_int = int(version)
+            except ValueError:
+                # Our versions are all ints so if we can't convert it to an integer,
+                # it doesn't exist.
+                raise StoreError(404, "No backup with that version exists")
+
+            await self.db_pool.simple_update_one(
                 table="e2e_room_keys_versions",
-                keyvalues={"user_id": user_id, "version": version},
+                keyvalues={"user_id": user_id, "version": version_int},
                 updatevalues=updatevalues,
                 desc="update_e2e_room_keys_version",
             )
@@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 or if the version requested doesn't exist.
         """
 
-        def _delete_e2e_room_keys_version_txn(txn):
+        def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
-                if this_version is None:
-                    raise StoreError(404, "No current backup version")
             else:
-                this_version = version
+                try:
+                    this_version = int(version)
+                except ValueError:
+                    # Our versions are all ints so if we can't convert it to an integer,
+                    # it isn't there.
+                    raise StoreError(404, "No backup with that version exists")
 
             self.db_pool.simple_delete_txn(
                 txn,