diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 936a64669c..b3cede37e3 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -58,6 +58,11 @@ class EndToEndKeyStore(SQLBaseStore):
def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until,
key_list):
def _add_e2e_one_time_keys(txn):
+ sql = (
+ "DELETE FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?"
+ )
+ txn.execute(sql, (user_id, device_id, time_now))
for (algorithm, key_id, json_bytes) in key_list:
self._simple_upsert_txn(
txn, table="e2e_one_time_keys_json",
@@ -84,16 +89,11 @@ class EndToEndKeyStore(SQLBaseStore):
"""
def _count_e2e_one_time_keys(txn):
sql = (
- "DELETE FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?"
- )
- txn.execute(sql, (user_id, device_id, time_now))
- sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ?"
+ " WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?"
" GROUP BY algorithm"
)
- txn.execute(sql, (user_id, device_id))
+ txn.execute(sql, (user_id, device_id, time_now))
result = {}
for algorithm, key_count in txn.fetchall():
result[algorithm] = key_count
|