summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/device.py2
-rw-r--r--synapse/handlers/e2e_keys.py11
-rw-r--r--synapse/storage/devices.py3
-rw-r--r--synapse/storage/end_to_end_keys.py62
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--tests/storage/test_end_to_end_keys.py8
6 files changed, 63 insertions, 25 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 6fefb85890..7245d14fab 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -203,7 +203,7 @@ class DeviceHandler(BaseHandler):
         hosts = set()
         if self.hs.is_mine_id(user_id):
             for room_id in room_ids:
-                users = yield self.state.get_current_user_in_room(room_id)
+                users = yield self.store.get_users_in_room(room_id)
                 hosts.update(get_domain_from_id(u) for u in users)
             hosts.discard(self.server_name)
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index a16b9def8d..e40495d1ab 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -194,7 +194,7 @@ class E2eKeysHandler(object):
         # "unsigned" section
         for user_id, device_keys in results.items():
             for device_id, device_info in device_keys.items():
-                r = json.loads(device_info["key_json"])
+                r = dict(device_info["keys"])
                 r["unsigned"] = {}
                 display_name = device_info["device_display_name"]
                 if display_name is not None:
@@ -287,11 +287,12 @@ class E2eKeysHandler(object):
                 device_id, user_id, time_now
             )
             # TODO: Sign the JSON with the server key
-            yield self.store.set_e2e_device_keys(
-                user_id, device_id, time_now,
-                encode_canonical_json(device_keys)
+            changed = yield self.store.set_e2e_device_keys(
+                user_id, device_id, time_now, device_keys,
             )
-            yield self.device_handler.notify_device_update(user_id, [device_id])
+            if changed:
+                # Only notify about device updates *if* the keys actually changed
+                yield self.device_handler.notify_device_update(user_id, [device_id])
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index e68ee50152..f0353929da 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -164,6 +164,7 @@ class DeviceStore(SQLBaseStore):
             keyvalues={
                 "user_id": user_id,
             },
+            desc="mark_remote_user_device_list_as_unsubscribed",
         )
 
     def update_remote_device_list_cache_entry(self, user_id, device_id, content,
@@ -463,7 +464,7 @@ class DeviceStore(SQLBaseStore):
             SELECT user_id FROM device_lists_stream WHERE stream_id > ?
         """
         rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
-        defer.returnValue(set(row["user_id"] for row in rows))
+        defer.returnValue(set(row[0] for row in rows))
 
     def get_all_device_list_changes_for_remotes(self, from_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 85763f7ceb..2040e022fa 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,23 +14,53 @@
 # limitations under the License.
 from twisted.internet import defer
 
+from canonicaljson import encode_canonical_json
+import ujson as json
+
 from ._base import SQLBaseStore
 
 
 class EndToEndKeyStore(SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
-        return self._simple_upsert(
-            table="e2e_device_keys_json",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "ts_added_ms": time_now,
-                "key_json": json_bytes,
-            }
+    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+        def _set_e2e_device_keys_txn(txn):
+            old_key_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            new_key_json = encode_canonical_json(device_keys)
+            if old_key_json == new_key_json:
+                return False
+
+            self._simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "ts_added_ms": time_now,
+                    "key_json": new_key_json,
+                }
+            )
+
+            return True
+
+        return self.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
         )
 
+    @defer.inlineCallbacks
     def get_e2e_device_keys(self, query_list, include_all_devices=False):
         """Fetch a list of device keys.
         Args:
@@ -42,13 +72,19 @@ class EndToEndKeyStore(SQLBaseStore):
             dict containing "key_json", "device_display_name".
         """
         if not query_list:
-            return {}
+            defer.returnValue({})
 
-        return self.runInteraction(
+        results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
             query_list, include_all_devices,
         )
 
+        for user_id, device_keys in results.iteritems():
+            for device_id, device_info in device_keys.iteritems():
+                device_info["keys"] = json.loads(device_info.pop("key_json"))
+
+        defer.returnValue(results)
+
     def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
         query_clauses = []
         query_params = []
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 768e0a4451..0fdcf29085 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -131,7 +131,7 @@ class RoomMemberStore(SQLBaseStore):
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
-    @cached(max_entries=5000)
+    @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
 
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index bfa6294250..84ce492a2c 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -33,7 +33,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_key_without_device_name(self):
         now = 1470174257070
-        json = '{ "key": "value" }'
+        json = {"key": "value"}
 
         yield self.store.store_device(
             "user", "device", None
@@ -47,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
         self.assertDictContainsSubset({
-            "key_json": json,
+            "keys": json,
             "device_display_name": None,
         }, dev)
 
     @defer.inlineCallbacks
     def test_get_key_with_device_name(self):
         now = 1470174257070
-        json = '{ "key": "value" }'
+        json = {"key": "value"}
 
         yield self.store.set_e2e_device_keys(
             "user", "device", now, json)
@@ -67,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
         self.assertDictContainsSubset({
-            "key_json": json,
+            "keys": json,
             "device_display_name": "display_name",
         }, dev)