summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py11
-rw-r--r--synapse/replication/slave/storage/devices.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/storage/end_to_end_keys.py35
4 files changed, 34 insertions, 15 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c0205da1a9..91c6c6be3c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -117,6 +117,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "archived",  # ArchivedSyncResult for each archived room.
     "to_device",  # List of direct messages for the device.
     "device_lists",  # List of user_ids whose devices have chanegd
+    "device_one_time_keys_count",  # Dict of algorithm to count for one time keys
+                                   # for this device
 ])):
     __slots__ = []
 
@@ -550,6 +552,14 @@ class SyncHandler(object):
             sync_result_builder
         )
 
+        device_id = sync_config.device_id
+        one_time_key_counts = {}
+        if device_id:
+            user_id = sync_config.user.to_string()
+            one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+                user_id, device_id
+            )
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -558,6 +568,7 @@ class SyncHandler(object):
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
+            device_one_time_keys_count=one_time_key_counts,
             next_batch=sync_result_builder.now_token,
         ))
 
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 4d4a435471..7687867aee 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,6 +16,7 @@
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 from synapse.storage import DataStore
+from synapse.storage.end_to_end_keys import EndToEndKeyStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
@@ -45,6 +46,7 @@ class SlavedDeviceStore(BaseSlavedStore):
     _mark_as_sent_devices_by_remote_txn = (
         DataStore._mark_as_sent_devices_by_remote_txn.__func__
     )
+    count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
 
     def stream_positions(self):
         result = super(SlavedDeviceStore, self).stream_positions()
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 771e127ab9..83e209d18f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -192,6 +192,7 @@ class SyncRestServlet(RestServlet):
                 "invite": invited,
                 "leave": archived,
             },
+            "device_one_time_keys_count": sync_result.device_one_time_keys_count,
             "next_batch": sync_result.next_batch.to_string(),
         }
 
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index e00f31da2b..2cebb203c6 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -185,8 +185,8 @@ class EndToEndKeyStore(SQLBaseStore):
                     for algorithm, key_id, json_bytes in new_keys
                 ],
             )
-            txn.call_after(
-                self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id,)
             )
         yield self.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
@@ -237,24 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             for user_id, device_id, algorithm, key_id in delete:
                 txn.execute(sql, (user_id, device_id, algorithm, key_id))
-                txn.call_after(
-                    self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+                self._invalidate_cache_and_stream(
+                    txn, self.count_e2e_one_time_keys, (user_id, device_id,)
                 )
             return result
         return self.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    @defer.inlineCallbacks
     def delete_e2e_keys_by_device(self, user_id, device_id):
-        yield self._simple_delete(
-            table="e2e_device_keys_json",
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            desc="delete_e2e_device_keys_by_device"
-        )
-        yield self._simple_delete(
-            table="e2e_one_time_keys_json",
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            desc="delete_e2e_one_time_keys_by_device"
+        def delete_e2e_keys_by_device_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+            )
+        return self.runInteraction(
+            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
-        self.count_e2e_one_time_keys.invalidate((user_id, device_id,))