summary refs log tree commit diff
path: root/synapse/storage/end_to_end_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/end_to_end_keys.py')
-rw-r--r--synapse/storage/end_to_end_keys.py71
1 files changed, 31 insertions, 40 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py

index e381e472a2..2fabb9e2cb 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py
@@ -26,8 +26,7 @@ from ._base import SQLBaseStore, db_to_json class EndToEndKeyWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_device_keys( - self, query_list, include_all_devices=False, - include_deleted_devices=False, + self, query_list, include_all_devices=False, include_deleted_devices=False ): """Fetch a list of device keys. Args: @@ -45,8 +44,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): defer.returnValue({}) results = yield self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, include_deleted_devices, + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, ) for user_id, device_keys in iteritems(results): @@ -56,8 +58,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): defer.returnValue(results) def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, - include_deleted_devices=False, + self, txn, query_list, include_all_devices=False, include_deleted_devices=False ): query_clauses = [] query_params = [] @@ -87,7 +88,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): " WHERE %s" ) % ( "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses) + " OR ".join("(" + q + ")" for q in query_clauses), ) txn.execute(sql, query_params) @@ -124,17 +125,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore): table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, - retcols=("algorithm", "key_id", "key_json",), - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, + retcols=("algorithm", "key_id", "key_json"), + keyvalues={"user_id": user_id, "device_id": device_id}, desc="add_e2e_one_time_keys_check", ) - defer.returnValue({ - (row["algorithm"], row["key_id"]): row["key_json"] for row in rows - }) + defer.returnValue( + {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} + ) @defer.inlineCallbacks def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): @@ -155,7 +153,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. self._simple_insert_many_txn( - txn, table="e2e_one_time_keys_json", + txn, + table="e2e_one_time_keys_json", values=[ { "user_id": user_id, @@ -169,8 +168,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ], ) self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id,) + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + yield self.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -181,6 +181,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: Dict mapping from algorithm to number of keys for that algorithm. """ + def _count_e2e_one_time_keys(txn): sql = ( "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" @@ -192,9 +193,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for algorithm, key_count in txn: result[algorithm] = key_count return result - return self.runInteraction( - "count_e2e_one_time_keys", _count_e2e_one_time_keys - ) + + return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): @@ -202,14 +202,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. """ + def _set_e2e_device_keys_txn(txn): old_key_json = self._simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, + keyvalues={"user_id": user_id, "device_id": device_id}, retcol="key_json", allow_none=True, ) @@ -224,24 +222,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._simple_upsert_txn( txn, table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": new_key_json, - } + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, ) return True - return self.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn - ) + return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" + def _claim_e2e_one_time_keys(txn): sql = ( "SELECT key_id, key_json FROM e2e_one_time_keys_json" @@ -265,12 +256,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id,) + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) return result - return self.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + + return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device_txn(txn): @@ -285,8 +275,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): keyvalues={"user_id": user_id, "device_id": device_id}, ) self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id,) + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + return self.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn )