summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-07-08 17:04:29 +0100
committerMark Haines <mark.haines@matrix.org>2015-07-08 17:04:29 +0100
commit8fb79eeea4b1c388771785024b79e84b4206fc24 (patch)
treea208ccabc4a18158ab48d6d94ae135813d933c1f /synapse
parentMerge branch 'mjark/missing_regex_group' into markjh/client-end-to-end-key-ma... (diff)
downloadsynapse-8fb79eeea4b1c388771785024b79e84b4206fc24.tar.xz
Only remove one time keys when new one time keys are added
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/end_to_end_keys.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 936a64669c..b3cede37e3 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -58,6 +58,11 @@ class EndToEndKeyStore(SQLBaseStore):
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until,
                               key_list):
         def _add_e2e_one_time_keys(txn):
+            sql = (
+                "DELETE FROM e2e_one_time_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?"
+            )
+            txn.execute(sql, (user_id, device_id, time_now))
             for (algorithm, key_id, json_bytes) in key_list:
                 self._simple_upsert_txn(
                     txn, table="e2e_one_time_keys_json",
@@ -84,16 +89,11 @@ class EndToEndKeyStore(SQLBaseStore):
         """
         def _count_e2e_one_time_keys(txn):
             sql = (
-                "DELETE FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?"
-            )
-            txn.execute(sql, (user_id, device_id, time_now))
-            sql = (
                 "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
-                " WHERE user_id = ? AND device_id = ?"
+                " WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?"
                 " GROUP BY algorithm"
             )
-            txn.execute(sql, (user_id, device_id))
+            txn.execute(sql, (user_id, device_id, time_now))
             result = {}
             for algorithm, key_count in txn.fetchall():
                 result[algorithm] = key_count