diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ecdab34e7d..e716dc1437 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -22,11 +22,10 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, db_to_json
-
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
)
-class DeviceStore(BackgroundUpdateStore):
+class DeviceWorkerStore(SQLBaseStore):
+ def get_device(self, user_id, device_id):
+ """Retrieve a device.
+
+ Args:
+ user_id (str): The ID of the user which owns the device
+ device_id (str): The ID of the device to retrieve
+ Returns:
+ defer.Deferred for a dict containing the device information
+ (the "user_id", "device_id" and "display_name" columns of the
+ matching row in the `devices` table)
+ Raises:
+ StoreError: if the device is not found
+ """
+ return self._simple_select_one(
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_device",
+ )
+
+ @defer.inlineCallbacks
+ def get_devices_by_user(self, user_id):
+ """Retrieve all of a user's registered devices.
+
+ Args:
+ user_id (str): The ID of the user whose devices are looked up
+ Returns:
+ defer.Deferred: resolves to a dict from device_id to a dict
+ containing "device_id", "user_id" and "display_name" for each
+ device.
+ """
+ devices = yield self._simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user"
+ )
+
+ # Re-key the selected rows by their device_id.
+ defer.returnValue({d["device_id"]: d for d in devices})
+
+ def get_devices_by_remote(self, destination, from_stream_id):
+ """Get stream of updates to send to remote servers
+
+ Args:
+ destination: the remote server the updates are destined for
+ (presumably a server name string; confirm against callers)
+ from_stream_id: only updates with a stream id strictly greater
+ than this are considered; it is passed through int() for the
+ cache check
+
+ Returns:
+ (int, list[dict]): current stream id and list of updates
+
+ NOTE(review): the "nothing changed" path returns this tuple directly,
+ while the other path returns whatever `runInteraction` returns
+ (presumably a Deferred resolving to the tuple) - callers need to
+ cope with both; confirm.
+ """
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ # Ask the in-memory federation stream cache first, so that we can skip
+ # the database entirely when nothing has changed for this destination.
+ has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+ destination, int(from_stream_id)
+ )
+ if not has_changed:
+ return (now_stream_id, [])
+
+ return self.runInteraction(
+ "get_devices_by_remote", self._get_devices_by_remote_txn,
+ destination, from_stream_id, now_stream_id,
+ )
+
+ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+ now_stream_id):
+ """Transaction body for `get_devices_by_remote`.
+
+ Selects unsent rows from `device_lists_outbound_pokes` for
+ `destination` in the range (from_stream_id, now_stream_id] and turns
+ them into a list of update dicts.
+
+ Args:
+ txn: the database transaction/cursor to run the queries on
+ destination: the remote server the updates are destined for
+ from_stream_id: exclusive lower bound on the poke stream ids
+ now_stream_id: inclusive upper bound on the poke stream ids
+
+ Returns:
+ (int, list[dict]): the stream id the results are valid up to, and
+ the updates. Each update has "user_id", "device_id", "prev_id"
+ and "stream_id", plus either "deleted" or optional "keys" /
+ "device_display_name".
+ """
+ # At most 20 (user_id, device_id) pairs are fetched per call, each with
+ # the highest stream_id recorded for that pair.
+ # NOTE(review): there is no ORDER BY, so which 20 groups come back is
+ # not defined by the query. Combined with the `now_stream_id` bump
+ # below, pokes with lower stream ids that were not selected could be
+ # skipped - confirm whether an ordering is needed here.
+ sql = """
+ SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+ GROUP BY user_id, device_id
+ LIMIT 20
+ """
+ txn.execute(
+ sql, (destination, from_stream_id, now_stream_id, False)
+ )
+
+ # maps (user_id, device_id) -> stream_id
+ query_map = {(r[0], r[1]): r[2] for r in txn}
+ if not query_map:
+ return (now_stream_id, [])
+
+ # If the LIMIT was reached there may be more rows than we fetched, so
+ # report the highest stream id we actually picked up instead of the
+ # current token.
+ if len(query_map) >= 20:
+ now_stream_id = max(stream_id for stream_id in itervalues(query_map))
+
+ # `_get_e2e_device_keys_txn` is defined elsewhere; from its use below it
+ # returns a mapping of user_id -> device_id -> device info (or None).
+ devices = self._get_e2e_device_keys_txn(
+ txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
+ )
+
+ prev_sent_id_sql = """
+ SELECT coalesce(max(stream_id), 0) as stream_id
+ FROM device_lists_outbound_last_success
+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
+ """
+
+ results = []
+ for user_id, user_devices in iteritems(devices):
+ # The prev_id for the first row is always the last row before
+ # `from_stream_id`
+ txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+ rows = txn.fetchall()
+ prev_id = rows[0][0]
+ for device_id, device in iteritems(user_devices):
+ stream_id = query_map[(user_id, device_id)]
+ # A prev_id of 0 (the coalesce default) is reported as "no
+ # previous update", i.e. an empty list.
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "prev_id": [prev_id] if prev_id else [],
+ "stream_id": stream_id,
+ }
+
+ # Chain the updates: the next update for this user refers back to
+ # this one's stream id.
+ prev_id = stream_id
+
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ # No device info came back for this device, so it is reported
+ # as deleted.
+ result["deleted"] = True
+
+ results.append(result)
+
+ return (now_stream_id, results)
+
+ def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ """Mark that updates have successfully been sent to the destination.
+
+ Args:
+ destination: the remote server the updates were sent to
+ stream_id: all outbound pokes for `destination` with a stream id
+ less than or equal to this are treated as sent
+
+ Returns:
+ The result of `runInteraction` (presumably a Deferred; confirm).
+ """
+ return self.runInteraction(
+ "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+ destination, stream_id,
+ )
+
+ def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ """Transaction body for `mark_as_sent_devices_by_remote`.
+
+ Records the highest poked stream id per user in
+ `device_lists_outbound_last_success`, then deletes the pokes for
+ `destination` up to and including `stream_id`.
+
+ Args:
+ txn: the database transaction/cursor to run the queries on
+ destination: the remote server the updates were sent to
+ stream_id: inclusive upper bound of the pokes that were sent
+ """
+ # We update the device_lists_outbound_last_success with the successfully
+ # poked users. We do the join to see which users need to be inserted and
+ # which updated.
+ #
+ # Each row is (user_id, highest poke stream_id, whether a last_success
+ # row already exists for this destination and user).
+ sql = """
+ SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ FROM device_lists_outbound_pokes as o
+ LEFT JOIN device_lists_outbound_last_success as s
+ USING (destination, user_id)
+ WHERE destination = ? AND o.stream_id <= ?
+ GROUP BY user_id
+ """
+ txn.execute(sql, (destination, stream_id,))
+ rows = txn.fetchall()
+
+ # Users that already have a last_success row (row[2] truthy): update it.
+ sql = """
+ UPDATE device_lists_outbound_last_success
+ SET stream_id = ?
+ WHERE destination = ? AND user_id = ?
+ """
+ txn.executemany(
+ sql, ((row[1], destination, row[0],) for row in rows if row[2])
+ )
+
+ # Users without a last_success row yet (row[2] falsey): insert one.
+ sql = """
+ INSERT INTO device_lists_outbound_last_success
+ (destination, user_id, stream_id) VALUES (?, ?, ?)
+ """
+ txn.executemany(
+ sql, ((destination, row[0], row[1],) for row in rows if not row[2])
+ )
+
+ # Delete all sent outbound pokes
+ sql = """
+ DELETE FROM device_lists_outbound_pokes
+ WHERE destination = ? AND stream_id <= ?
+ """
+ txn.execute(sql, (destination, stream_id,))
+
+ def get_device_stream_token(self):
+ """Return the current token of the device list stream id generator.
+ """
+ return self._device_list_id_gen.get_current_token()
+
+ @defer.inlineCallbacks
+ def get_user_devices_from_cache(self, query_list):
+ """Get the devices (and keys if any) for remote users from the cache.
+
+ Args:
+ query_list(list): List of (user_id, device_id) pairs; if device_id
+ is falsey then return all devices for that user.
+
+ Returns:
+ (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+ a set of user_ids and results_map is a mapping of
+ user_id -> device_id -> device_info
+ """
+ user_ids = set(user_id for user_id, _ in query_list)
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ # A user only counts as cached if we hold a truthy last stream id for
+ # them; a None (or 0) stream id puts them in `user_ids_not_in_cache`.
+ user_ids_in_cache = set(
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ )
+ user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+ results = {}
+ for user_id, device_id in query_list:
+ # Uncached users are reported via `user_ids_not_in_cache` only.
+ if user_id not in user_ids_in_cache:
+ continue
+
+ if device_id:
+ # A specific device was requested: add just that one entry.
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
+ else:
+ # No device given: replace the user's entry with all of their
+ # cached devices.
+ results[user_id] = yield self._get_cached_devices_for_user(user_id)
+
+ defer.returnValue((user_ids_not_in_cache, results))
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ """Fetch one remote device's cached content.
+
+ Reads the `content` column of the `device_lists_remote_cache` row for
+ the given user and device and decodes it with `db_to_json`.
+
+ Args:
+ user_id: the user that owns the device
+ device_id: the device to look up
+
+ Returns:
+ Via `defer.returnValue`, the decoded content.
+
+ NOTE(review): `allow_none` is not passed to the select, so a missing
+ row presumably raises - confirm against `_simple_select_one_onecol`.
+ """
+ content = yield self._simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ defer.returnValue(db_to_json(content))
+
+ @cachedInlineCallbacks()
+ def _get_cached_devices_for_user(self, user_id):
+ """Fetch the cached content of all of a remote user's devices.
+
+ Args:
+ user_id: the user whose rows in `device_lists_remote_cache` are
+ read
+
+ Returns:
+ Via `defer.returnValue`, a dict mapping device_id to that device's
+ `content` column decoded with `db_to_json`.
+ """
+ devices = yield self._simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ desc="_get_cached_devices_for_user",
+ )
+ defer.returnValue({
+ device["device_id"]: db_to_json(device["content"])
+ for device in devices
+ })
+
+ def get_devices_with_keys_by_user(self, user_id):
+ """Get all devices (with any device keys) for a user
+
+ Args:
+ user_id: the user whose devices are looked up
+
+ Returns:
+ (stream_id, devices), as produced by
+ `_get_devices_with_keys_by_user_txn` and delivered through
+ `runInteraction` (presumably as a Deferred; confirm).
+ """
+ return self.runInteraction(
+ "get_devices_with_keys_by_user",
+ self._get_devices_with_keys_by_user_txn, user_id,
+ )
+
+ def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ """Transaction body for `get_devices_with_keys_by_user`.
+
+ Args:
+ txn: the database transaction/cursor to run the queries on
+ user_id: the user whose devices are looked up
+
+ Returns:
+ (int, list[dict]): the current device list stream token and one
+ dict per device, with "device_id" and, when present, "keys" and
+ "device_display_name". The list is empty if no devices were found.
+ """
+ now_stream_id = self._device_list_id_gen.get_current_token()
+
+ # A device_id of None in the query appears to mean "all devices for this
+ # user" - confirm against `_get_e2e_device_keys_txn`.
+ devices = self._get_e2e_device_keys_txn(
+ txn, [(user_id, None)], include_all_devices=True
+ )
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in iteritems(user_devices):
+ result = {
+ "device_id": device_id,
+ }
+
+ # Keys and display name are only included when they are set.
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = db_to_json(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
+ @defer.inlineCallbacks
+ def get_user_whose_devices_changed(self, from_key):
+ """Get set of users whose devices have changed since `from_key`.
+
+ Args:
+ from_key: stream id to compare against; it is converted with int()
+
+ Returns:
+ Via `defer.returnValue`, a set of user ids.
+ """
+ from_key = int(from_key)
+ # Try the in-memory stream cache first. A None result is treated as
+ # "the cache cannot answer", in which case we fall through to the
+ # database query below.
+ changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+ if changed is not None:
+ defer.returnValue(set(changed))
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
+ """
+ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+ defer.returnValue(set(row[0] for row in rows))
+
+ def get_all_device_list_changes_for_remotes(self, from_key, to_key):
+ """Return a list of `(stream_id, user_id, destination)` which is the
+ combined list of changes to devices, and which destinations need to be
+ poked. `destination` may be None if no destinations need to be poked.
+
+ Args:
+ from_key: exclusive lower bound on the stream id
+ to_key: inclusive upper bound on the stream id
+ """
+ # We do a group by here as there can be a large number of duplicate
+ # entries, since we throw away device IDs.
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, user_id, destination
+ FROM device_lists_stream
+ LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id, destination
+ """
+ return self._execute(
+ "get_all_device_list_changes_for_remotes", None,
+ sql, from_key, to_key
+ )
+
+ @cached(max_entries=10000)
+ def get_device_list_last_stream_id_for_remote(self, user_id):
+ """Get the last stream_id we got for a user. May be None if we haven't
+ got any information for them.
+
+ Args:
+ user_id: the user to look up in `device_lists_remote_extremeties`
+
+ Returns:
+ The result of `_simple_select_one_onecol` for the `stream_id`
+ column (presumably a Deferred; confirm).
+ """
+ return self._simple_select_one_onecol(
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ retcol="stream_id",
+ desc="get_device_list_last_stream_id_for_remote",
+ allow_none=True,
+ )
+
+ @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids", inlineCallbacks=True)
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ """Batch form of `get_device_list_last_stream_id_for_remote`.
+
+ Args:
+ user_ids: iterable of user ids to look up in
+ `device_lists_remote_extremeties`
+
+ Returns:
+ Via `defer.returnValue`, a dict mapping every requested user id to
+ its stream_id, or to None when no row was returned for that user.
+ """
+ rows = yield self._simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id",),
+ desc="get_device_list_last_stream_id_for_remotes",
+ )
+
+ # Start with None for every requested user, then overwrite the entries
+ # for which the database returned a row.
+ results = {user_id: None for user_id in user_ids}
+ results.update({
+ row["user_id"]: row["stream_id"] for row in rows
+ })
+
+ defer.returnValue(results)
+
+
+class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
@@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
initial_device_display_name, e)
raise StoreError(500, "Problem storing device.")
- def get_device(self, user_id, device_id):
- """Retrieve a device.
-
- Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to retrieve
- Returns:
- defer.Deferred for a dict containing the device information
- Raises:
- StoreError: if the device is not found
- """
- return self._simple_select_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_device",
- )
-
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
"""Delete a device.
@@ -203,57 +520,6 @@ class DeviceStore(BackgroundUpdateStore):
)
@defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
- """Retrieve all of a user's registered devices.
-
- Args:
- user_id (str):
- Returns:
- defer.Deferred: resolves to a dict from device_id to a dict
- containing "device_id", "user_id" and "display_name" for each
- device.
- """
- devices = yield self._simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user"
- )
-
- defer.returnValue({d["device_id"]: d for d in devices})
-
- @cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id):
- """Get the last stream_id we got for a user. May be None if we haven't
- got any information for them.
- """
- return self._simple_select_one_onecol(
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- retcol="stream_id",
- desc="get_device_list_remote_extremity",
- allow_none=True,
- )
-
- @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
- list_name="user_ids", inlineCallbacks=True)
- def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id",),
- desc="get_user_devices_from_cache",
- )
-
- results = {user_id: None for user_id in user_ids}
- results.update({
- row["user_id"]: row["stream_id"] for row in rows
- })
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
@@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
lock=False,
)
- def get_devices_by_remote(self, destination, from_stream_id):
- """Get stream of updates to send to remote servers
-
- Returns:
- (int, list[dict]): current stream id and list of updates
- """
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- has_changed = self._device_list_federation_stream_cache.has_entity_changed(
- destination, int(from_stream_id)
- )
- if not has_changed:
- return (now_stream_id, [])
-
- return self.runInteraction(
- "get_devices_by_remote", self._get_devices_by_remote_txn,
- destination, from_stream_id, now_stream_id,
- )
-
- def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
- now_stream_id):
- sql = """
- SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
- WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
- GROUP BY user_id, device_id
- LIMIT 20
- """
- txn.execute(
- sql, (destination, from_stream_id, now_stream_id, False)
- )
-
- # maps (user_id, device_id) -> stream_id
- query_map = {(r[0], r[1]): r[2] for r in txn}
- if not query_map:
- return (now_stream_id, [])
-
- if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in itervalues(query_map))
-
- devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
- )
-
- prev_sent_id_sql = """
- SELECT coalesce(max(stream_id), 0) as stream_id
- FROM device_lists_outbound_last_success
- WHERE destination = ? AND user_id = ? AND stream_id <= ?
- """
-
- results = []
- for user_id, user_devices in iteritems(devices):
- # The prev_id for the first row is always the last row before
- # `from_stream_id`
- txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
- rows = txn.fetchall()
- prev_id = rows[0][0]
- for device_id, device in iteritems(user_devices):
- stream_id = query_map[(user_id, device_id)]
- result = {
- "user_id": user_id,
- "device_id": device_id,
- "prev_id": [prev_id] if prev_id else [],
- "stream_id": stream_id,
- }
-
- prev_id = stream_id
-
- if device is not None:
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
- else:
- result["deleted"] = True
-
- results.append(result)
-
- return (now_stream_id, results)
-
- @defer.inlineCallbacks
- def get_user_devices_from_cache(self, query_list):
- """Get the devices (and keys if any) for remote users from the cache.
-
- Args:
- query_list(list): List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
-
- Returns:
- (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
- a set of user_ids and results_map is a mapping of
- user_id -> device_id -> device_info
- """
- user_ids = set(user_id for user_id, _ in query_list)
- user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
- user_ids_in_cache = set(
- user_id for user_id, stream_id in user_map.items() if stream_id
- )
- user_ids_not_in_cache = user_ids - user_ids_in_cache
-
- results = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
-
- if device_id:
- device = yield self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
- results[user_id] = yield self._get_cached_devices_for_user(user_id)
-
- defer.returnValue((user_ids_not_in_cache, results))
-
- @cachedInlineCallbacks(num_args=2, tree=True)
- def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="content",
- desc="_get_cached_user_device",
- )
- defer.returnValue(db_to_json(content))
-
- @cachedInlineCallbacks()
- def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("device_id", "content"),
- desc="_get_cached_devices_for_user",
- )
- defer.returnValue({
- device["device_id"]: db_to_json(device["content"])
- for device in devices
- })
-
- def get_devices_with_keys_by_user(self, user_id):
- """Get all devices (with any device keys) for a user
-
- Returns:
- (stream_id, devices)
- """
- return self.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn, user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(self, txn, user_id):
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(
- txn, [(user_id, None)], include_all_devices=True
- )
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in iteritems(user_devices):
- result = {
- "device_id": device_id,
- }
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
- def mark_as_sent_devices_by_remote(self, destination, stream_id):
- """Mark that updates have successfully been sent to the destination.
- """
- return self.runInteraction(
- "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
- destination, stream_id,
- )
-
- def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
- # We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
- sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
- FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
- WHERE destination = ? AND o.stream_id <= ?
- GROUP BY user_id
- """
- txn.execute(sql, (destination, stream_id,))
- rows = txn.fetchall()
-
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(
- sql, ((row[1], destination, row[0],) for row in rows if row[2])
- )
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1],) for row in rows if not row[2])
- )
-
- # Delete all sent outbound pokes
- sql = """
- DELETE FROM device_lists_outbound_pokes
- WHERE destination = ? AND stream_id <= ?
- """
- txn.execute(sql, (destination, stream_id,))
-
- @defer.inlineCallbacks
- def get_user_whose_devices_changed(self, from_key):
- """Get set of users whose devices have changed since `from_key`.
- """
- from_key = int(from_key)
- changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
- if changed is not None:
- defer.returnValue(set(changed))
-
- sql = """
- SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
- """
- rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
- defer.returnValue(set(row[0] for row in rows))
-
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
- """
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
- sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
- WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
- """
- return self._execute(
- "get_all_device_list_changes_for_remotes", None,
- sql, from_key, to_key
- )
-
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
@@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
]
)
- def get_device_stream_token(self):
- return self._device_list_id_gen.get_current_token()
-
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
|