diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
new file mode 100644
index 0000000000..88a7aadfc6
--- /dev/null
+++ b/synapse/storage/databases/main/devices.py
@@ -0,0 +1,1311 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import List, Optional, Set, Tuple
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, StoreError
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ set_tag,
+ trace,
+ whitelisted_homeserver,
+)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
+from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.util.caches.descriptors import (
+ Cache,
+ cached,
+ cachedInlineCallbacks,
+ cachedList,
+)
+from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
+
+logger = logging.getLogger(__name__)
+
+DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
+ "drop_device_list_streams_non_unique_indexes"
+)
+
+BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+
+
+class DeviceWorkerStore(SQLBaseStore):
+ def get_device(self, user_id, device_id):
+ """Retrieve a device. Only returns devices that are not marked as
+ hidden.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self.db_pool.simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices. Only returns devices
+ that are not marked as hidden.
+
+ Args:
+ user_id (str):
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user",
+ )
+
+ return {d["device_id"]: d for d in devices}
+
+ @trace
+ @defer.inlineCallbacks
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
+ Returns:
+            Deferred[tuple[int, list[tuple[str, dict]]]]:
+ current stream id (ie, the stream id of the last update included in the
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
+ """
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
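+        # Consult the in-memory stream change cache first: if nothing for this
+        # destination has changed since `from_stream_id`, we can skip the
+        # database query entirely.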
+ has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+ destination, int(from_stream_id)
+ )
+ if not has_changed:
+ return now_stream_id, []
+
+ updates = yield self.db_pool.runInteraction(
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
+ destination,
+ from_stream_id,
+ now_stream_id,
+ limit,
+ )
+
+ # Return an empty list if there are no updates
+ if not updates:
+ return now_stream_id, []
+
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = {r[0] for r in updates}
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ # Perform the equivalent of a GROUP BY
+ #
+ # Iterate through the updates list and copy non-duplicate
+ # (user_id, device_id) entries into a map, with the value being
+ # the max stream_id across each set of duplicate entries
+ #
+ # maps (user_id, device_id) -> (stream_id, opentracing_context)
+ #
+ # opentracing_context contains the opentracing metadata for the request
+ # that created the poke
+ #
+ # The most recent request's opentracing_context is used as the
+        # context which created the EDU.
+
+ query_map = {}
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
+
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
+
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
+
+ results = yield self._get_device_update_edus_by_remote(
+ destination, from_stream_id, query_map
+ )
+
+ # add the updated cross-signing keys to the results list
+ for user_id, result in cross_signing_keys_by_user.items():
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
+
+ return now_stream_id, results
+
+ def _get_device_updates_by_remote_txn(
+ self, txn, destination, from_stream_id, now_stream_id, limit
+ ):
+ """Return device update information for a given remote destination
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ now_stream_id (int): The maximum stream_id to filter updates by, inclusive
+ limit (int): Maximum number of device updates to return
+
+        Returns:
+            List[Tuple]: List of device update rows, each a tuple of
+                (user_id, device_id, stream_id, opentracing_context)
+ """
+ # get the list of device updates that need to be sent
+ sql = """
+ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id
+ LIMIT ?
+ """
+ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
+
+ return list(txn)
+
+ @defer.inlineCallbacks
+ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
+ """Returns a list of device update EDUs as well as E2EE keys
+
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+            query_map (Dict[Tuple[str, str], Tuple[int, Optional[str]]]): Dictionary
+                mapping (user_id, device_id) to update stream_id and the relevant
+                json-encoded opentracing context
+
+ Returns:
+            List[Tuple[str, Dict]]: List of (EDU type, EDU contents) tuples
+                representing device update EDUs
+
+ """
+ devices = (
+ yield self.db_pool.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
+ )
+
+ results = []
+ for user_id, user_devices in devices.items():
+ # The prev_id for the first row is always the last row before
+ # `from_stream_id`
+ prev_id = yield self._get_last_device_update_for_remote_user(
+ destination, user_id, from_stream_id
+ )
+
+ # make sure we go through the devices in stream order
+ device_ids = sorted(
+ user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+ )
+
+ for device_id in device_ids:
+ device = user_devices[device_id]
+ stream_id, opentracing_context = query_map[(user_id, device_id)]
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_id] if prev_id else [],
+ "stream_id": stream_id,
+ "org.matrix.opentracing_context": opentracing_context,
+ }
+
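+                # Chain prev_ids within this batch: the next EDU for this user
+                # will point at the update we are sending now.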
+ prev_id = stream_id
+
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ result["deleted"] = True
+
+ results.append(("m.device_list_update", result))
+
+ return results
+
+ def _get_last_device_update_for_remote_user(
+ self, destination, user_id, from_stream_id
+ ):
+ def f(txn):
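+            # coalesce() means we get 0 back (rather than NULL) if we have never
+            # successfully sent a device update for this user to this destination.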
+ prev_sent_id_sql = """
+ SELECT coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
+ """
+ txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+
+ return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+
+ def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ """Mark that updates have successfully been sent to the destination.
+ """
+ return self.db_pool.runInteraction(
+ "mark_as_sent_devices_by_remote",
+ self._mark_as_sent_devices_by_remote_txn,
+ destination,
+ stream_id,
+ )
+
+ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users.
+ sql = """
+ SELECT user_id, coalesce(max(o.stream_id), 0)
+ FROM device_lists_outbound_pokes as o
+ WHERE destination = ? AND o.stream_id <= ?
+ GROUP BY user_id
+ """
+ txn.execute(sql, (destination, stream_id))
+ rows = txn.fetchall()
+
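+        # Record, for each user, the highest stream_id that has now been sent to
+        # this destination, so that later deltas can use it as their prev_id.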
+ self.db_pool.simple_upsert_many_txn(
+ txn=txn,
+ table="device_lists_outbound_last_success",
+ key_names=("destination", "user_id"),
+ key_values=((destination, user_id) for user_id, _ in rows),
+ value_names=("stream_id",),
+ value_values=((stream_id,) for _, stream_id in rows),
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id <= ?
+ """
+ txn.execute(sql, (destination, stream_id))
+
+ @defer.inlineCallbacks
+ def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ """Persist that a user has made new signatures
+
+ Args:
+ from_user_id (str): the user who made the signatures
+ user_ids (list[str]): the users who were signed
+ """
+
+ with self._device_list_id_gen.get_next() as stream_id:
+ yield self.db_pool.runInteraction(
+ "add_user_sig_change_to_streams",
+ self._add_user_signature_change_txn,
+ from_user_id,
+ user_ids,
+ stream_id,
+ )
+ return stream_id
+
+ def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ txn.call_after(
+ self._user_signature_stream_cache.entity_has_changed,
+ from_user_id,
+ stream_id,
+ )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "user_signature_stream",
+ values={
+ "stream_id": stream_id,
+ "from_user_id": from_user_id,
+ "user_ids": json.dumps(user_ids),
+ },
+ )
+
+ def get_device_stream_token(self):
+ return self._device_list_id_gen.get_current_token()
+
+ @trace
+ @defer.inlineCallbacks
+ def get_user_devices_from_cache(self, query_list):
+ """Get the devices (and keys if any) for remote users from the cache.
+
+ Args:
+            query_list(list): List of (user_id, device_id) tuples; if device_id
+                is falsey then return all device ids for that user.
+
+ Returns:
+ (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+ a set of user_ids and results_map is a mapping of
+ user_id -> device_id -> device_info
+ """
+ user_ids = {user_id for user_id, _ in query_list}
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+
+ # We go and check if any of the users need to have their device lists
+ # resynced. If they do then we remove them from the cached list.
+ users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+ user_ids
+ )
+ user_ids_in_cache = {
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ } - users_needing_resync
+ user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+ results = {}
+ for user_id, device_id in query_list:
+ if user_id not in user_ids_in_cache:
+ continue
+
+ if device_id:
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
+ else:
+ results[user_id] = yield self.get_cached_devices_for_user(user_id)
+
+ set_tag("in_cache", results)
+ set_tag("not_in_cache", user_ids_not_in_cache)
+
+ return user_ids_not_in_cache, results
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ content = yield self.db_pool.simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ return db_to_json(content)
+
+ @cachedInlineCallbacks()
+ def get_cached_devices_for_user(self, user_id):
+ devices = yield self.db_pool.simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("device_id", "content"),
+ desc="get_cached_devices_for_user",
+ )
+ return {
+ device["device_id"]: db_to_json(device["content"]) for device in devices
+ }
+
+ def get_devices_with_keys_by_user(self, user_id):
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ return self.db_pool.runInteraction(
+ "get_devices_with_keys_by_user",
+ self._get_devices_with_keys_by_user_txn,
+ user_id,
+ )
+
+ def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ devices = self._get_e2e_device_keys_txn(
+ txn, [(user_id, None)], include_all_devices=True
+ )
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
+ def get_users_whose_devices_changed(self, from_key, user_ids):
+ """Get set of users whose devices have changed since `from_key` that
+ are in the given list of user_ids.
+
+ Args:
+ from_key (str): The device lists stream token
+ user_ids (Iterable[str])
+
+ Returns:
+ Deferred[set[str]]: The set of user_ids whose devices have changed
+ since `from_key`
+ """
+ from_key = int(from_key)
+
+ # Get set of users who *may* have changed. Users not in the returned
+ # list have definitely not changed.
+ to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
+
+ if not to_check:
+ return defer.succeed(set())
+
+ def _get_users_whose_devices_changed_txn(txn):
+ changes = set()
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE stream_id > ?
+ AND
+ """
+
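+            # Query in batches of 100 users so the IN (...) clause (and the
+            # number of bound parameters) stays a manageable size.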
+ for chunk in batch_iter(to_check, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql + clause, (from_key,) + tuple(args))
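+                # each row is a 1-tuple, hence the trailing comma when unpacking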
+ changes.update(user_id for user_id, in txn)
+
+ return changes
+
+ return self.db_pool.runInteraction(
+ "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_users_whose_signatures_changed(self, user_id, from_key):
+ """Get the users who have new cross-signing signatures made by `user_id` since
+ `from_key`.
+
+ Args:
+ user_id (str): the user who made the signatures
+ from_key (str): The device lists stream token
+ """
+ from_key = int(from_key)
+ if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+ sql = """
+ SELECT DISTINCT user_ids FROM user_signature_stream
+ WHERE from_user_id = ? AND stream_id > ?
+ """
+ rows = yield self.db_pool.execute(
+ "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ )
+ return {user for row in rows for user in db_to_json(row[0])}
+ else:
+ return set()
+
+ async def get_all_device_list_changes_for_remotes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for device lists replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_device_list_changes_for_remotes(txn):
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
+ sql = """
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
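+            # If we hit the limit then we may not have returned everything up to
+            # current_id, so report the stream ID of the last row we did fetch as
+            # the token to resume from.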
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
+ "get_all_device_list_changes_for_remotes",
+ _get_all_device_list_changes_for_remotes,
+ )
+
+ @cached(max_entries=10000)
+ def get_device_list_last_stream_id_for_remote(self, user_id):
+ """Get the last stream_id we got for a user. May be None if we haven't
+ got any information for them.
+ """
+ return self.db_pool.simple_select_one_onecol(
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ retcol="stream_id",
+ desc="get_device_list_last_stream_id_for_remote",
+ allow_none=True,
+ )
+
+ @cachedList(
+ cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids",
+ inlineCallbacks=True,
+ )
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ rows = yield self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id"),
+ desc="get_device_list_last_stream_id_for_remotes",
+ )
+
+ results = {user_id: None for user_id in user_ids}
+ results.update({row["user_id"]: row["stream_id"] for row in rows})
+
+ return results
+
+ @defer.inlineCallbacks
+ def get_user_ids_requiring_device_list_resync(
+ self, user_ids: Optional[Collection[str]] = None,
+ ) -> Set[str]:
+ """Given a list of remote users return the list of users that we
+ should resync the device lists for. If None is given instead of a list,
+ return every user that we should resync the device lists for.
+
+ Returns:
+ The IDs of users whose device lists need resync.
+ """
+ if user_ids:
+ rows = yield self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ )
+ else:
+ rows = yield self.db_pool.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ )
+
+ return {row["user_id"] for row in rows}
+
+ def mark_remote_user_device_cache_as_stale(self, user_id: str):
+ """Records that the server has reason to believe the cache of the devices
+ for the remote users is out of date.
+ """
+ return self.db_pool.simple_upsert(
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
+ values={},
+ insertion_values={"added_ts": self._clock.time_msec()},
+ desc="make_remote_user_device_cache_as_stale",
+ )
+
+ def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+ """Mark that we no longer track device lists for remote user.
+ """
+
+ def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
+ )
+
+ return self.db_pool.runInteraction(
+ "mark_remote_user_device_list_as_unsubscribed",
+ _mark_remote_user_device_list_as_unsubscribed_txn,
+ )
+
+
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "device_lists_stream_idx",
+ index_name="device_lists_stream_user_id",
+ table="device_lists_stream",
+ columns=["user_id", "device_id"],
+ )
+
+ # create a unique index on device_lists_remote_cache
+ self.db_pool.updates.register_background_index_update(
+ "device_lists_remote_cache_unique_idx",
+ index_name="device_lists_remote_cache_unique_id",
+ table="device_lists_remote_cache",
+ columns=["user_id", "device_id"],
+ unique=True,
+ )
+
+ # And one on device_lists_remote_extremeties
+ self.db_pool.updates.register_background_index_update(
+ "device_lists_remote_extremeties_unique_idx",
+ index_name="device_lists_remote_extremeties_unique_idx",
+ table="device_lists_remote_extremeties",
+ columns=["user_id"],
+ unique=True,
+ )
+
+ # once they complete, we can remove the old non-unique indexes.
+ self.db_pool.updates.register_background_update_handler(
+ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
+ self._drop_device_list_streams_non_unique_indexes,
+ )
+
+ # clear out duplicate device list outbound pokes
+ self.db_pool.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+ )
+
+ # a pair of background updates that were added during the 1.14 release cycle,
+ # but replaced with 58/06dlols_unique_idx.py
+ self.db_pool.updates.register_noop_background_update(
+ "device_lists_outbound_last_success_unique_idx",
+ )
+ self.db_pool.updates.register_noop_background_update(
+ "drop_device_lists_outbound_last_success_non_unique_idx",
+ )
+
+ @defer.inlineCallbacks
+ def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ def f(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
+ txn.close()
+
+ yield self.db_pool.runWithConnection(f)
+ yield self.db_pool.updates._end_background_update(
+ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+ )
+ return 1
+
+ async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+        # for some reason, we have accumulated duplicate entries in
+        # device_lists_outbound_pokes, which makes
+        # _prune_old_outbound_device_pokes less efficient.
+ #
+ # For each duplicate, we delete all the existing rows and put one back.
+
+ KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
+ last_row = progress.get(
+ "last_row",
+ {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
+ )
+
+ def _txn(txn):
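+            # Page through the table in (stream_id, destination, user_id,
+            # device_id) order: the tuple comparison clause selects rows strictly
+            # after `last_row`, so each batch resumes where the previous stopped.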
+ clause, args = make_tuple_comparison_clause(
+ self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
+ )
+ sql = """
+ SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
+ FROM device_lists_outbound_pokes
+ WHERE %s
+ GROUP BY %s
+ HAVING count(*) > 1
+ ORDER BY %s
+ LIMIT ?
+ """ % (
+ clause, # WHERE
+ ",".join(KEY_COLS), # GROUP BY
+ ",".join(KEY_COLS), # ORDER BY
+ )
+ txn.execute(sql, args + [batch_size])
+ rows = self.db_pool.cursor_to_dict(txn)
+
+ row = None
+ for row in rows:
+ self.db_pool.simple_delete_txn(
+ txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+ )
+
+ row["sent"] = False
+ self.db_pool.simple_insert_txn(
+ txn, "device_lists_outbound_pokes", row,
+ )
+
+ if row:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+ )
+
+ return len(rows)
+
+ rows = await self.db_pool.runInteraction(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
+ )
+
+ if not rows:
+ await self.db_pool.updates._end_background_update(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
+ )
+
+ return rows
+
+
+class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super(DeviceStore, self).__init__(database, db_conn, hs)
+
+        # Map of (user_id, device_id) -> bool. If there is an entry, that implies
+        # the device exists.
+ self.device_id_exists_cache = Cache(
+ name="device_id_exists", keylen=2, max_entries=10000
+ )
+
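+        # Prune old outbound device list pokes once an hour, so that entries for
+        # dead servers don't accumulate forever.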
+ self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+
+ @defer.inlineCallbacks
+ def store_device(self, user_id, device_id, initial_device_display_name):
+ """Ensure the given device is known; add it to the store if not
+
+ Args:
+ user_id (str): id of user associated with the device
+ device_id (str): id of device
+ initial_device_display_name (str): initial displayname of the
+ device. Ignored if device exists.
+ Returns:
+            defer.Deferred: boolean; True if the device was inserted, False if
+            an existing device already existed with that ID.
+ Raises:
+ StoreError: if the device is already in use
+ """
+ key = (user_id, device_id)
+ if self.device_id_exists_cache.get(key, None):
+ return False
+
+ try:
+ inserted = yield self.db_pool.simple_insert(
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": device_id,
+ "display_name": initial_device_display_name,
+ "hidden": False,
+ },
+ desc="store_device",
+ or_ignore=True,
+ )
+ if not inserted:
+ # if the device already exists, check if it's a real device, or
+ # if the device ID is reserved by something else
+ hidden = yield self.db_pool.simple_select_one_onecol(
+ "devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="hidden",
+ )
+ if hidden:
+ raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
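+            # Either way the device now exists and is not hidden, so cache that
+            # fact to short-circuit future store_device calls.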
+ self.device_id_exists_cache.prefill(key, True)
+ return inserted
+ except StoreError:
+ raise
+ except Exception as e:
+ logger.error(
+ "store_device with device_id=%s(%r) user_id=%s(%r)"
+ " display_name=%s(%r) failed: %s",
+ type(device_id).__name__,
+ device_id,
+ type(user_id).__name__,
+ user_id,
+ type(initial_device_display_name).__name__,
+ initial_device_display_name,
+ e,
+ )
+ raise StoreError(500, "Problem storing device.")
+
+ @defer.inlineCallbacks
+ def delete_device(self, user_id, device_id):
+ """Delete a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to delete
+ Returns:
+ defer.Deferred
+ """
+ yield self.db_pool.simple_delete_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+ desc="delete_device",
+ )
+
+ self.device_id_exists_cache.invalidate((user_id, device_id))
+
+ @defer.inlineCallbacks
+ def delete_devices(self, user_id, device_ids):
+ """Deletes several devices.
+
+ Args:
+ user_id (str): The ID of the user which owns the devices
+ device_ids (list): The IDs of the devices to delete
+ Returns:
+ defer.Deferred
+ """
+ yield self.db_pool.simple_delete_many(
+ table="devices",
+ column="device_id",
+ iterable=device_ids,
+ keyvalues={"user_id": user_id, "hidden": False},
+ desc="delete_devices",
+ )
+ for device_id in device_ids:
+ self.device_id_exists_cache.invalidate((user_id, device_id))
+
+ def update_device(self, user_id, device_id, new_display_name=None):
+ """Update a device. Only updates the device if it is not marked as
+ hidden.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to update
+ new_display_name (str|None): new displayname for device; None
+ to leave unchanged
+ Raises:
+ StoreError: if the device is not found
+ Returns:
+ defer.Deferred
+ """
+ updates = {}
+ if new_display_name is not None:
+ updates["display_name"] = new_display_name
+ if not updates:
+ return defer.succeed(None)
+ return self.db_pool.simple_update_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+ updatevalues=updates,
+ desc="update_device",
+ )
+
+ def update_remote_device_list_cache_entry(
+ self, user_id, device_id, content, stream_id
+ ):
+ """Updates a single device in the cache of a remote user's devicelist.
+
+ Note: assumes that we are the only thread that can be updating this user's
+ device list.
+
+ Args:
+ user_id (str): User to update device list for
+            device_id (str): ID of the device being updated
+ content (dict): new data on this device
+ stream_id (int): the version of the device list
+
+ Returns:
+ Deferred[None]
+ """
+ return self.db_pool.runInteraction(
+ "update_remote_device_list_cache_entry",
+ self._update_remote_device_list_cache_entry_txn,
+ user_id,
+ device_id,
+ content,
+ stream_id,
+ )
+
+ def _update_remote_device_list_cache_entry_txn(
+ self, txn, user_id, device_id, content, stream_id
+ ):
+ if content.get("deleted"):
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+
+ txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
+ else:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"content": json.dumps(content)},
+ # we don't need to lock, because we assume we are the only thread
+ # updating this user's devices.
+ lock=False,
+ )
+
+ txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ values={"stream_id": stream_id},
+ # again, we can assume we are the only thread updating this user's
+ # extremity.
+ lock=False,
+ )
+
+ def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ """Replace the entire cache of the remote user's devices.
+
+ Note: assumes that we are the only thread that can be updating this user's
+ device list.
+
+ Args:
+ user_id (str): User to update device list for
+ devices (list[dict]): list of device objects supplied over federation
+ stream_id (int): the version of the device list
+
+ Returns:
+ Deferred[None]
+ """
+ return self.db_pool.runInteraction(
+ "update_remote_device_list_cache",
+ self._update_remote_device_list_cache_txn,
+ user_id,
+ devices,
+ stream_id,
+ )
+
+ def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
+ self.db_pool.simple_delete_txn(
+ txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_remote_cache",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": content["device_id"],
+ "content": json.dumps(content),
+ }
+ for content in devices
+ ],
+ )
+
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ values={"stream_id": stream_id},
+ # we don't need to lock, because we can assume we are the only thread
+ # updating this user's extremity.
+ lock=False,
+ )
+
+ # If we're replacing the remote user's device list cache presumably
+ # we've done a full resync, so we remove the entry that says we need
+ # to resync
+ self.db_pool.simple_delete_txn(
+ txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+ )
+
+ @defer.inlineCallbacks
+ def add_device_change_to_streams(self, user_id, device_ids, hosts):
+ """Persist that a user's devices have been updated, and which hosts
+ (if any) should be poked.
+ """
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ yield self.db_pool.runInteraction(
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
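+        # Capture the active opentracing context so it can be stored alongside
+        # each outbound poke (see _add_device_outbound_poke_to_stream_txn).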
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
+ yield self.db_pool.runInteraction(
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
+ user_id,
+ device_ids,
+ hosts,
+ stream_ids,
+ context,
+ )
+
+ return stream_ids[-1]
+
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+        stream_ids: List[int],
+ ):
+ txn.call_after(
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
+ )
+
+ min_stream_id = stream_ids[0]
+
+ # Delete older entries in the table, as we really only care about
+ # when the latest change happened.
+ txn.executemany(
+ """
+ DELETE FROM device_lists_stream
+ WHERE user_id = ? AND device_id = ? AND stream_id < ?
+ """,
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_stream",
+ values=[
+ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
+ for stream_id, device_id in zip(stream_ids, device_ids)
+ ],
+ )
+
+ def _add_device_outbound_poke_to_stream_txn(
+ self, txn, user_id, device_ids, hosts, stream_ids, context,
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
+
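+        # stream_ids contains exactly len(hosts) * len(device_ids) entries, so we
+        # consume one per (destination, device_id) row inserted below.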
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_outbound_pokes",
+ values=[
+ {
+ "destination": destination,
+ "stream_id": next(next_stream_id),
+ "user_id": user_id,
+ "device_id": device_id,
+ "sent": False,
+ "ts": now,
+ "opentracing_context": json.dumps(context)
+ if whitelisted_homeserver(destination)
+ else "{}",
+ }
+ for destination in hosts
+ for device_id in device_ids
+ ],
+ )
+
+ def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
+ """Delete old entries out of the device_lists_outbound_pokes to ensure
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send an m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
+ """
+ yesterday = self._clock.time_msec() - prune_age
+
+ def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
+ select_sql = """
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
+ GROUP BY destination, user_id
+ HAVING min(ts) < ? AND count(*) > 1
+ """
+
+ txn.execute(select_sql, (yesterday,))
+ rows = txn.fetchall()
+
+ if not rows:
+ return
+
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+        # same stream_id, so we also delete all but one of the rows with the max
+        # stream id.
+ delete_sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
+ """
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
+
+ # Since we've deleted unsent deltas, we need to remove the entry
+ # of last successful sent so that the prev_ids are correctly set.
+ sql = """
+ DELETE FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+ logger.info("Pruned %d device list outbound pokes", count)
+
+ return run_as_background_process(
+ "prune_old_outbound_device_pokes",
+ self.db_pool.runInteraction,
+ "_prune_old_outbound_device_pokes",
+ _prune_txn,
+ )
|