diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/storage/devices.py | 201 |
1 files changed, 189 insertions, 12 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9628e2ff75..8ee3119db2 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + def get_device_list_remote_extremity(self, user_id): + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() @@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore): txn.execute(prev_sent_id_sql, (destination, user_id, True)) rows = txn.fetchall() prev_id = rows[0][0] - for device_id, result in user_devices.iteritems(): + for device_id, device in user_devices.iteritems(): stream_id = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore): prev_id = stream_id - key_json = result.get("key_json", None) + key_json = device.get("key_json", None) if key_json: result["keys"] = json.loads(key_json) - device_display_name = result.get("device_display_name", None) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + def get_user_devices_from_cache(self, query_list): + return self.runInteraction( + "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, + query_list, + ) + + def _get_user_devices_from_cache_txn(self, txn, query_list): + user_ids = {user_id for user_id, _ in query_list} + + user_ids_in_cache = set() + for user_id in user_ids: + stream_ids = self._simple_select_onecol_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + retcol="stream_id", + ) + if stream_ids: + user_ids_in_cache.add(user_id) + + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + content = self._simple_select_one_onecol_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + ) + results.setdefault(user_id, {})[device_id] = json.loads(content) + else: + devices = self._simple_select_list_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + ) + results[user_id] = { + device["device_id"]: json.loads(device["content"]) + for device in devices + } + user_ids_in_cache.discard(user_id) + + return user_ids_not_in_cache, results + + def get_devices_with_keys_by_user(self, user_id): + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + for user_id, user_devices in devices.iteritems(): + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + def mark_as_sent_devices_by_remote(self, destination, stream_id): return self.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, @@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_id, hosts): + def add_device_change_to_streams(self, user_id, device_ids, hosts): # device_lists_stream # device_lists_outbound_pokes with self._device_list_id_gen.get_next() as stream_id: yield self.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, - user_id, device_id, hosts, stream_id, + user_id, device_ids, hosts, stream_id, ) defer.returnValue(stream_id) - def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id): + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): txn.call_after( self._device_list_stream_cache.entity_has_changed, user_id, stream_id, @@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) - self._simple_insert_txn( + self._simple_insert_many_txn( txn, table="device_lists_stream", - values={ - "stream_id": stream_id, - "user_id": user_id, - "device_id": device_id, - } + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] ) self._simple_insert_many_txn( @@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore): "sent": False, } for destination in hosts + for device_id in device_ids ] ) |