diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
+
+from typing_extensions import Literal, TypedDict
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict, JsonSerializable
from synapse.util import json_encoder
+class RoomKey(TypedDict):
+ """`KeyBackupData` in the Matrix spec.
+
+ https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
+ """
+
+ first_message_index: int
+ forwarded_count: int
+ is_verified: bool
+ session_data: JsonSerializable
+
+
class EndToEndRoomKeyStore(SQLBaseStore):
+ """The store for end to end room key backups.
+
+ See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
+
+ As per the spec, backups are identified by an opaque version string. Internally,
+ version identifiers are assigned using incrementing integers. Non-numeric version
+ strings are treated as if they do not exist, since we would have never issued them.
+ """
+
async def update_e2e_room_key(
- self, user_id, version, room_id, session_id, room_key
- ):
+ self,
+ user_id: str,
+ version: str,
+ room_id: str,
+ session_id: str,
+ room_key: RoomKey,
+ ) -> None:
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_id(str): the ID of the room whose keys we're setting
- session_id(str): the session whose room_key we're setting
- room_key(dict): the room_key being set
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_id: the ID of the room whose keys we're setting
+ session_id: the session whose room_key we're setting
+ room_key: the room_key being set
Raises:
StoreError
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
- "version": version,
+ "version": version_int,
"room_id": room_id,
"session_id": session_id,
},
@@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="update_e2e_room_key",
)
- async def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(
+ self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
+ ) -> None:
"""Bulk add room keys to a given backup.
Args:
- user_id (str): the user whose backup we're adding to
- version (str): the version ID of the backup for the set of keys we're adding to
- room_keys (iterable[(str, str, dict)]): the keys to add, in the form
- (roomID, sessionID, keyData)
+ user_id: the user whose backup we're adding to
+ version: the version ID of the backup for the set of keys we're adding to
+ room_keys: the keys to add, in the form (roomID, sessionID, keyData)
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
{
"user_id": user_id,
- "version": version,
+ "version": version_int,
"room_id": room_id,
"session_id": session_id,
"first_message_index": room_key["first_message_index"],
@@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup for the set of keys we're querying
- room_id (str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup for the set of keys we're querying
+ room_id: Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id (str): Optional. the session whose room_key we're querying, if any.
+ session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
Returns:
- A list of dicts giving the session_data and message metadata for
- these room keys.
+ A dict giving the session_data and message metadata for these room keys.
+ `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
"""
try:
- version = int(version)
+ version_int = int(version)
except ValueError:
return {"rooms": {}}
- keyvalues = {"user_id": user_id, "version": version}
+ keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
@@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="get_e2e_room_keys",
)
- sessions = {"rooms": {}}
+ sessions: Dict[
+ Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+ ] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
@@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
- async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ async def get_e2e_room_keys_multi(
+ self,
+ user_id: str,
+ version: str,
+ room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+ ) -> Dict[str, Dict[str, RoomKey]]:
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
specific key.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
- room_keys (dict[str, dict[str, iterable[str]]]): a map from
- room ID -> {"session": [session ids]} indicating the session IDs
- that we want to query
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
+ room_keys: a map from room ID -> {"sessions": [session ids]}
+ indicating the session IDs that we want to query
Returns:
- dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
+ A map of room IDs to session IDs to room key
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return {}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
- version,
+ version_int,
room_keys,
)
@staticmethod
- def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ def _get_e2e_room_keys_multi_txn(
+ txn: LoggingTransaction,
+ user_id: str,
+ version: int,
+ room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+ ) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys:
return {}
@@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn.execute(sql, params)
- ret = {}
+ ret: Dict[str, Dict[str, RoomKey]] = {}
for row in txn:
room_id = row[0]
@@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return 0
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
- keyvalues={"user_id": user_id, "version": version},
+ keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
@trace
async def delete_e2e_room_keys(
- self, user_id, version, room_id=None, session_id=None
- ):
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> None:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
Args:
- user_id(str): the user whose backup we're deleting from
- version(str): the version ID of the backup for the set of keys we're deleting
- room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
+ user_id: the user whose backup we're deleting from
+ version: the version ID of the backup for the set of keys we're deleting
+ room_id: Optional. the ID of the room whose keys we're deleting, if any.
If not specified, we delete the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we delete all the keys in this version of
the backup (or for the specified room)
-
- Returns:
- The deletion transaction
"""
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ return
- keyvalues = {"user_id": user_id, "version": int(version)}
+ keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
@@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@staticmethod
- def _get_current_version(txn, user_id):
+ def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0",
(user_id,),
)
- row = txn.fetchone()
- if not row:
+ # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
+ # be `NULL` when there are no available versions.
+ row = cast(Tuple[Optional[int]], txn.fetchone())
+ if row[0] is None:
raise StoreError(404, "No current backup version")
return row[0]
- async def get_e2e_room_keys_version_info(self, user_id, version=None):
+ async def get_e2e_room_keys_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get info metadata about a version of our room_keys backup.
Args:
- user_id(str): the user whose backup we're querying
- version(str): Optional. the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: Optional. the version ID of the backup we're querying about
If missing, we return the information about the current version.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
@@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
etag(int): tag of the keys in the backup
"""
- def _get_e2e_room_keys_version_info_txn(txn):
+ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
@@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
- raise StoreError(404, "No row found")
+ raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
+ allow_none=False,
)
+ assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
@@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
+ async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
Args:
- user_id(str): the user whose backup we're creating a version
- info(dict): the info about the backup version to be created
+ user_id: the user whose backup we're creating a version
+ info: the info about the backup version to be created
Returns:
The newly created version ID
"""
- def _create_e2e_room_keys_version_txn(txn):
+ def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,),
)
- current_version = txn.fetchone()[0]
+ current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
if current_version is None:
- current_version = "0"
+ current_version = 0
- new_version = str(int(current_version) + 1)
+ new_version = current_version + 1
self.db_pool.simple_insert_txn(
txn,
@@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
},
)
- return new_version
+ return str(new_version)
return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
self,
user_id: str,
version: str,
- info: Optional[dict] = None,
+ info: Optional[JsonDict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version
@@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version_etag: etag of the keys in the backup. If None, then the etag
is not updated.
"""
- updatevalues = {}
+ updatevalues: Dict[str, object] = {}
if info is not None and "auth_data" in info:
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag
if updatevalues:
- await self.db_pool.simple_update(
+ try:
+ version_int = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it doesn't exist.
+ raise StoreError(404, "No backup with that version exists")
+
+ await self.db_pool.simple_update_one(
table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
+ keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
@@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
or if the version requested doesn't exist.
"""
- def _delete_e2e_room_keys_version_txn(txn):
+ def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None:
this_version = self._get_current_version(txn, user_id)
- if this_version is None:
- raise StoreError(404, "No current backup version")
else:
- this_version = version
+ try:
+ this_version = int(version)
+ except ValueError:
+ # Our versions are all ints so if we can't convert it to an integer,
+ # it isn't there.
+ raise StoreError(404, "No backup with that version exists")
self.db_pool.simple_delete_txn(
txn,
|