diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 9628e2ff75..8ee3119db2 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
+ def get_device_list_remote_extremity(self, user_id):
+ return self._simple_select_one_onecol(
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ retcol="stream_id",
+ desc="get_device_list_remote_extremity",
+ allow_none=True,
+ )
+
+ def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+ stream_id):
+ return self.runInteraction(
+ "update_remote_device_list_cache_entry",
+ self._update_remote_device_list_cache_entry_txn,
+ user_id, device_id, content, stream_id,
+ )
+
+ def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+ content, stream_id):
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "content": json.dumps(content),
+ }
+ )
+
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "stream_id": stream_id,
+ }
+ )
+
+ def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ return self.runInteraction(
+ "update_remote_device_list_cache",
+ self._update_remote_device_list_cache_txn,
+ user_id, devices, stream_id,
+ )
+
+ def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+ stream_id):
+ self._simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="device_lists_remote_cache",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": content["device_id"],
+ "content": json.dumps(content),
+ }
+ for content in devices
+ ]
+ )
+
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={
+ "user_id": user_id,
+ },
+ values={
+ "stream_id": stream_id,
+ }
+ )
+
def get_devices_by_remote(self, destination, from_stream_id):
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
txn.execute(prev_sent_id_sql, (destination, user_id, True))
rows = txn.fetchall()
prev_id = rows[0][0]
- for device_id, result in user_devices.iteritems():
+ for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
- key_json = result.get("key_json", None)
+ key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
- device_display_name = result.get("device_display_name", None)
+ device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results)
+ def get_user_devices_from_cache(self, query_list):
+ return self.runInteraction(
+ "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+ query_list,
+ )
+
+ def _get_user_devices_from_cache_txn(self, txn, query_list):
+ user_ids = {user_id for user_id, _ in query_list}
+
+ user_ids_in_cache = set()
+ for user_id in user_ids:
+ stream_ids = self._simple_select_onecol_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcol="stream_id",
+ )
+ if stream_ids:
+ user_ids_in_cache.add(user_id)
+
+ user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+ results = {}
+ for user_id, device_id in query_list:
+ if user_id not in user_ids_in_cache:
+ continue
+
+ if device_id:
+ content = self._simple_select_one_onecol_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ )
+ results.setdefault(user_id, {})[device_id] = json.loads(content)
+ else:
+ devices = self._simple_select_list_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ )
+ results[user_id] = {
+ device["device_id"]: json.loads(device["content"])
+ for device in devices
+ }
+ user_ids_in_cache.discard(user_id)
+
+ return user_ids_not_in_cache, results
+
+ def get_devices_with_keys_by_user(self, user_id):
+ return self.runInteraction(
+ "get_devices_with_keys_by_user",
+ self._get_devices_with_keys_by_user_txn, user_id,
+ )
+
+ def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, [(user_id, None)], include_all_devices=True
+ )
+
+ for user_id, user_devices in devices.iteritems():
+ results = []
+ for device_id, device in user_devices.iteritems():
+ result = {
+ "device_id": device_id,
+ }
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = json.loads(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
def mark_as_sent_devices_by_remote(self, destination, stream_id):
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
@@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
defer.returnValue(set(row["user_id"] for row in rows))
@defer.inlineCallbacks
- def add_device_change_to_streams(self, user_id, device_id, hosts):
+ def add_device_change_to_streams(self, user_id, device_ids, hosts):
# device_lists_stream
# device_lists_outbound_pokes
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
- user_id, device_id, hosts, stream_id,
+ user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
- def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
+ def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
@@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
host, stream_id,
)
- self._simple_insert_txn(
+ self._simple_insert_many_txn(
txn,
table="device_lists_stream",
- values={
- "stream_id": stream_id,
- "user_id": user_id,
- "device_id": device_id,
- }
+ values=[
+ {
+ "stream_id": stream_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ for device_id in device_ids
+ ]
)
self._simple_insert_many_txn(
@@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
"sent": False,
}
for destination in hosts
+ for device_id in device_ids
]
)
|