diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 6464520386..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,17 @@
# limitations under the License.
import abc
import logging
-from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks:
@@ -414,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_ids: the users who were signed
Returns:
- THe new stream ID.
+ The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
@@ -1121,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
raise StoreError(500, "Problem storing device.")
async def delete_device(self, user_id: str, device_id: str) -> None:
- """Delete a device.
+ """Delete a device and its device_inbox.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to delete
"""
- await self.db_pool.simple_delete_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
- desc="delete_device",
- )
- self.device_id_exists_cache.invalidate((user_id, device_id))
+ await self.delete_devices(user_id, [device_id])
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
@@ -1142,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user which owns the devices
device_ids: The IDs of the devices to delete
"""
- await self.db_pool.simple_delete_many(
- table="devices",
- column="device_id",
- iterable=device_ids,
- keyvalues={"user_id": user_id, "hidden": False},
- desc="delete_devices",
- )
+
+ def _delete_devices_txn(txn: LoggingTransaction) -> None:
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="devices",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id, "hidden": False},
+ )
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_inbox",
+ column="device_id",
+ values=device_ids,
+ keyvalues={"user_id": user_id},
+ )
+
+ await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
@@ -1302,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ):
+ ) -> int:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
|