From 16b67c404d232bd133f03adc5e5eba60610f941f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 00:08:36 +0100 Subject: Make get_e2e_cross_signing_key delegate to get_e2e_cross_signing_keys_bulk ... mostly because the latter has a cache. --- .../storage/data_stores/main/end_to_end_keys.py | 60 +++------------------- 1 file changed, 6 insertions(+), 54 deletions(-) (limited to 'synapse') diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index d0b9742695..20698bfd16 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -270,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the key will be included in the result - - Returns: - dict of the key data or None if not found - """ - sql = ( - "SELECT keydata " - " FROM e2e_cross_signing_keys " - " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" - ) - txn.execute(sql, (user_id, key_type)) - row = txn.fetchone() - if not row: - return None - key = json.loads(row[0]) - - device_id = None - for k in key["keys"].values(): - device_id = k - - if from_user_id is not None: - sql = ( - "SELECT key_id, signature " - " FROM e2e_cross_signing_signatures " - " WHERE user_id = ? " - " AND target_user_id = ? " - " AND target_device_id = ? " - ) - txn.execute(sql, (from_user_id, user_id, device_id)) - row = txn.fetchone() - if row: - key.setdefault("signatures", {}).setdefault(from_user_id, {})[ - row[0] - ] = row[1] - - return key - + @defer.inlineCallbacks def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): """Returns a user's cross-signing key. @@ -331,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: dict of the key data or None if not found """ - return self.db.runInteraction( - "get_e2e_cross_signing_key", - self._get_e2e_cross_signing_key_txn, - user_id, - key_type, - from_user_id, - ) + res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + user_keys = res.get(user_id) + if not user_keys: + return None + return user_keys.get(key_type) @cached(num_args=1) def _get_bare_e2e_cross_signing_keys(self, user_id): -- cgit 1.4.1