diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..065145c0d2 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -143,9 +143,6 @@ class DataStore(
("device_lists_outbound_pokes", "stream_id"),
],
)
- self._cross_signing_id_gen = StreamIdGenerator(
- db_conn, "e2e_cross_signing_keys", "stream_id"
- )
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
import attr
from canonicaljson import encode_canonical_json
-from twisted.enterprise.adbapi import Connection
-
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
class EndToEndKeyBackgroundStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
)
-class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._allow_device_name_lookup_over_federation = (
@@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
- rv = {}
+ rv: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
@@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
+ # We've only looked up cross-signatures for non-deleted devices with key
+ # data.
+ assert target_device_result is not None
+ assert target_device_result.keys is not None
target_device_signatures = target_device_result.keys.setdefault(
"signatures", {}
)
@@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ self,
+ txn: LoggingTransaction,
+ query_list: Collection[Tuple[str, str]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
@@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
- self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+ self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
@@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
txn.execute(signature_sql, signature_query_params)
- return txn.fetchall()
+ return cast(
+ List[
+ Tuple[
+ str,
+ str,
+ str,
+ str,
+ ]
+ ],
+ txn.fetchall(),
+ )
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
@@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn):
+ def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
@@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
A mapping from algorithm to number of keys for that algorithm.
"""
- def _count_e2e_one_time_keys(txn):
+ def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
@@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
def _set_e2e_fallback_keys_txn(
- self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ fallback_keys: JsonDict,
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Returns a user's cross-signing key.
Args:
@@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id):
+ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
their user ID will map to None.
"""
- return await self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
+ # The `Optional` comes from the `@cachedList` decorator.
+ return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
- txn: Connection,
+ txn: LoggingTransaction,
user_ids: Iterable[str],
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_ids (list[str]): the users whose keys are being requested
+ txn: db connection
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, their user
- ID will not be in the dict.
+ Mapping from user ID to key type to key data.
+ If a user's cross-signing keys were not found, their user ID will not be in
+ the dict.
"""
- result = {}
+ result: Dict[str, Dict[str, JsonDict]] = {}
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
@@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
- user_info = result.setdefault(user_id, {})
- user_info[key_type] = key
+ user_keys = result.setdefault(user_id, {})
+ user_keys[key_type] = key
return result
def _get_e2e_cross_signing_signatures_txn(
self,
- txn: Connection,
- keys: Dict[str, Dict[str, dict]],
+ txn: LoggingTransaction,
+ keys: Dict[str, Optional[Dict[str, JsonDict]]],
from_user_id: str,
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- keys (dict[str, dict[str, dict]]): a map of user ID to key type to
- key data. This dict will be modified to add signatures.
- from_user_id (str): fetch the signatures made by this user
+ txn: db connection
+ keys: a map of user ID to key type to key data.
+ This dict will be modified to add signatures.
+ from_user_id: fetch the signatures made by this user
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. The return value will be the same as the keys argument,
- with the modifications included.
+ Mapping from user ID to key type to key data.
+ The return value will be the same as the keys argument, with the
+ modifications included.
"""
# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
- devices = {}
+ devices: Dict[Tuple[str, str], str] = {}
- for user_id, user_info in keys.items():
- if user_info is None:
+ for user_id, user_keys in keys.items():
+ if user_keys is None:
continue
- for key_type, key in user_info.items():
+ for key_type, key in user_keys.items():
device_id = None
for k in key["keys"].values():
device_id = k
+ # `key` ought to be a `CrossSigningKey`, whose .keys property is a
+ # dictionary with a single entry:
+ # "algorithm:base64_public_key": "base64_public_key"
+ # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
+ assert isinstance(device_id, str)
devices[(user_id, device_id)] = key_type
for batch in batch_iter(devices.keys(), size=100):
@@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
# and add the signatures to the appropriate keys
for row in rows:
- key_id = row["key_id"]
- target_user_id = row["target_user_id"]
- target_device_id = row["target_device_id"]
+ key_id: str = row["key_id"]
+ target_user_id: str = row["target_user_id"]
+ target_device_id: str = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
- user_info = keys[target_user_id] = keys[target_user_id].copy()
- target_user_key = user_info[key_type] = user_info[key_type].copy()
+ user_keys = keys[target_user_id]
+ # `user_keys` cannot be `None` because we only fetched signatures for
+ # users with keys
+ assert user_keys is not None
+ user_keys = keys[target_user_id] = user_keys.copy()
+
+ target_user_key = user_keys[key_type] = user_keys[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
@@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Dict[str, dict]]]:
+ ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
if last_id == current_id:
return [], current_id, False
- def _get_all_user_signature_changes_for_remotes_txn(txn):
+ def _get_all_user_signature_changes_for_remotes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
@@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
@trace
def _claim_e2e_one_time_key_simple(
- txn, user_id: str, device_id: str, algorithm: str
+ txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that don't support RETURNING.
@@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
@trace
def _claim_e2e_one_time_key_returning(
- txn, user_id: str, device_id: str, algorithm: str
+ txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
@@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
- results = {}
+ results: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._cross_signing_id_gen = StreamIdGenerator(
+ db_conn, "e2e_cross_signing_keys", "stream_id"
+ )
+
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
@@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
or the keys were already in the database.
"""
- def _set_e2e_device_keys_txn(txn):
+ def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
@@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
- def delete_e2e_keys_by_device_txn(txn):
+ def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
log_kv(
{
"message": "Deleting keys for device",
@@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
+ def _set_e2e_cross_signing_key_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ key_type: str,
+ key: JsonDict,
+ stream_id: int,
+ ) -> None:
"""Set a user's cross-signing key.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user to set the signing key for
- key_type (str): the type of key that is being set: either 'master'
+ txn: db connection
+ user_id: the user to set the signing key for
+ key_type: the type of key that is being set: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- key (dict): the key data
- stream_id (int)
+ key: the key data
+ stream_id
"""
# the 'key' dict will look something like:
# {
@@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- async def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, key: JsonDict
+ ) -> None:
"""Set a user's cross-signing key.
Args:
- user_id (str): the user to set the user-signing key for
- key_type (str): the type of cross-signing key to set
- key (dict): the key data
+ user_id: the user to set the user-signing key for
+ key_type: the type of cross-signing key to set
+ key: the key data
"""
async with self._cross_signing_id_gen.get_next() as stream_id:
|