diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 71f62036c0..9a828231c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -30,16 +30,16 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
- Cache,
- SQLBaseStore,
- db_to_json,
- make_in_list_sql_clause,
-)
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import (
+ Cache,
+ cached,
+ cachedInlineCallbacks,
+ cachedList,
+)
logger = logging.getLogger(__name__)
@@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore):
Raises:
StoreError: if the device is not found
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each
device.
"""
- devices = yield self._simple_select_list(
+ devices = yield self.db.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore):
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
- updates = yield self.runInteraction(
+ updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
@@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
devices = (
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
@@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.runInteraction("get_last_device_update_for_remote_user", f)
+ return self.db.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
from_user_id,
@@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore):
from_user_id,
stream_id,
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"user_signature_stream",
values={
@@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
+ content = yield self.db.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
+ devices = yield self.db.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
(stream_id, devices)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
@@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore):
return changes
- return self.runInteraction(
+ return self.db.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
@@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
- rows = yield self._execute(
+ rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return set(user for row in rows for user in json.loads(row[0]))
@@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
- return self._execute(
+ return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -642,11 +642,11 @@ class DeviceWorkerStore(SQLBaseStore):
return results
-class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
@@ -654,7 +654,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
)
# create a unique index on device_lists_remote_cache
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache",
@@ -663,7 +663,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
)
# And one on device_lists_remote_extremeties
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties",
@@ -672,7 +672,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
)
# once they complete, we can remove the old non-unique indexes.
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes,
)
@@ -685,14 +685,16 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
- yield self.runWithConnection(f)
- yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
+ yield self.db.runWithConnection(f)
+ yield self.db.updates._end_background_update(
+ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+ )
return 1
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(DeviceStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
@@ -722,7 +724,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = yield self._simple_insert(
+ inserted = yield self.db.simple_insert(
"devices",
values={
"user_id": user_id,
@@ -736,7 +738,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
- hidden = yield self._simple_select_one_onecol(
+ hidden = yield self.db.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@@ -771,7 +773,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_one(
+ yield self.db.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@@ -789,7 +791,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_many(
+ yield self.db.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@@ -818,7 +820,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
@@ -829,7 +831,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
@@ -853,7 +855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -866,7 +868,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -874,7 +876,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -890,7 +892,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -914,7 +916,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -923,11 +925,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
@@ -946,7 +948,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -962,7 +964,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
user_id,
@@ -995,7 +997,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, stream_id) for device_id in device_ids],
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -1006,7 +1008,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@@ -1069,7 +1071,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
- self.runInteraction,
+ self.db.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
|