summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/devices.py201
1 files changed, 189 insertions, 12 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 9628e2ff75..8ee3119db2 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
 
         defer.returnValue({d["device_id"]: d for d in devices})
 
+    def get_device_list_remote_extremity(self, user_id):
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_remote_extremity",
+            allow_none=True,
+        )
+
+    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+                                              stream_id):
+        return self.runInteraction(
+            "update_remote_device_list_cache_entry",
+            self._update_remote_device_list_cache_entry_txn,
+            user_id, device_id, content, stream_id,
+        )
+
+    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+                                                   content, stream_id):
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "content": json.dumps(content),
+            }
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+        return self.runInteraction(
+            "update_remote_device_list_cache",
+            self._update_remote_device_list_cache_txn,
+            user_id, devices, stream_id,
+        )
+
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+                                             stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_remote_cache",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": content["device_id"],
+                    "content": json.dumps(content),
+                }
+                for content in devices
+            ]
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
     def get_devices_by_remote(self, destination, from_stream_id):
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
             txn.execute(prev_sent_id_sql, (destination, user_id, True))
             rows = txn.fetchall()
             prev_id = rows[0][0]
-            for device_id, result in user_devices.iteritems():
+            for device_id, device in user_devices.iteritems():
                 stream_id = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = result.get("key_json", None)
+                key_json = device.get("key_json", None)
                 if key_json:
                     result["keys"] = json.loads(key_json)
-                device_display_name = result.get("device_display_name", None)
+                device_display_name = device.get("device_display_name", None)
                 if device_display_name:
                     result["device_display_name"] = device_display_name
 
@@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
 
         return (now_stream_id, results)
 
+    def get_user_devices_from_cache(self, query_list):
+        return self.runInteraction(
+            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+            query_list,
+        )
+
+    def _get_user_devices_from_cache_txn(self, txn, query_list):
+        user_ids = {user_id for user_id, _ in query_list}
+
+        user_ids_in_cache = set()
+        for user_id in user_ids:
+            stream_ids = self._simple_select_onecol_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={
+                    "user_id": user_id,
+                },
+                retcol="stream_id",
+            )
+            if stream_ids:
+                user_ids_in_cache.add(user_id)
+
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                content = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
+                    retcol="content",
+                )
+                results.setdefault(user_id, {})[device_id] = json.loads(content)
+            else:
+                devices = self._simple_select_list_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                    },
+                    retcols=("device_id", "content"),
+                )
+                results[user_id] = {
+                    device["device_id"]: json.loads(device["content"])
+                    for device in devices
+                }
+                user_ids_in_cache.discard(user_id)
+
+        return user_ids_not_in_cache, results
+
+    def get_devices_with_keys_by_user(self, user_id):
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        for user_id, user_devices in devices.iteritems():
+            results = []
+            for device_id, device in user_devices.iteritems():
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         return self.runInteraction(
             "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
@@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
         defer.returnValue(set(row["user_id"] for row in rows))
 
     @defer.inlineCallbacks
-    def add_device_change_to_streams(self, user_id, device_id, hosts):
+    def add_device_change_to_streams(self, user_id, device_ids, hosts):
         # device_lists_stream
         # device_lists_outbound_pokes
         with self._device_list_id_gen.get_next() as stream_id:
             yield self.runInteraction(
                 "add_device_change_to_streams", self._add_device_change_txn,
-                user_id, device_id, hosts, stream_id,
+                user_id, device_ids, hosts, stream_id,
             )
         defer.returnValue(stream_id)
 
-    def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
+    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
         txn.call_after(
             self._device_list_stream_cache.entity_has_changed,
             user_id, stream_id,
@@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
                 host, stream_id,
             )
 
-        self._simple_insert_txn(
+        self._simple_insert_many_txn(
             txn,
             table="device_lists_stream",
-            values={
-                "stream_id": stream_id,
-                "user_id": user_id,
-                "device_id": device_id,
-            }
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+                for device_id in device_ids
+            ]
         )
 
         self._simple_insert_many_txn(
@@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
                     "sent": False,
                 }
                 for destination in hosts
+                for device_id in device_ids
             ]
         )