diff --git a/changelog.d/9041.misc b/changelog.d/9041.misc
new file mode 100644
index 0000000000..4952fbe8a2
--- /dev/null
+++ b/changelog.d/9041.misc
@@ -0,0 +1 @@
+Various cleanups to device inbox store.
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..62b68dd6e9 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -18,7 +18,6 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -37,13 +36,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_inbox_id_gen.get_current_token(),
)
- self._last_device_delete_cache = ExpiringCache(
- cache_name="last_device_delete_cache",
- clock=self._clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- )
-
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index d42faa3f1f..eb72c21155 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -17,7 +17,7 @@ import logging
from typing import List, Tuple
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ # Map of (user_id, device_id) to the last stream_id that has been
+ # deleted up to. This is so that we can no op deletions.
+ self._last_device_delete_cache = ExpiringCache(
+ cache_name="last_device_delete_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ )
+
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
@@ -310,20 +322,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
- DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
-
- # Map of (user_id, device_id) to the last stream_id that has been
- # deleted up to. This is so that we can no op deletions.
- self._last_device_delete_cache = ExpiringCache(
- cache_name="last_device_delete_cache",
- clock=self._clock,
- max_len=10000,
- expiry_ms=30 * 60 * 1000,
- )
-
@trace
async def add_messages_to_device_inbox(
self,
@@ -351,16 +349,19 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add the remote messages to the federation outbox.
# We'll send them to a remote server when we next send a
# federation transaction to that destination.
- sql = (
- "INSERT INTO device_federation_outbox"
- " (destination, stream_id, queued_ts, messages_json)"
- " VALUES (?,?,?,?)"
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_federation_outbox",
+ values=[
+ {
+ "destination": destination,
+ "stream_id": stream_id,
+ "queued_ts": now_ms,
+ "messages_json": json_encoder.encode(edu),
+ }
+ for destination, edu in remote_messages_by_destination.items()
+ ],
)
- rows = []
- for destination, edu in remote_messages_by_destination.items():
- edu_json = json_encoder.encode(edu)
- rows.append((destination, stream_id, now_ms, edu_json))
- txn.executemany(sql, rows)
async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
@@ -433,32 +434,37 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = "SELECT device_id FROM devices WHERE user_id = ?"
- txn.execute(sql, (user_id,))
+ devices = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id},
+ retcol="device_id",
+ )
+
message_json = json_encoder.encode(messages_by_device["*"])
- for row in txn:
+ for device_id in devices:
# Add the message for all devices for this user on this
# server.
- device = row[0]
- messages_json_for_user[device] = message_json
+ messages_json_for_user[device_id] = message_json
else:
if not devices:
continue
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "device_id", devices
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id},
+ column="device_id",
+ iterable=devices,
+ retcols=("device_id",),
)
- sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
- # TODO: Maybe this needs to be done in batches if there are
- # too many local devices for a given user.
- txn.execute(sql, [user_id] + list(args))
- for row in txn:
+ for row in rows:
# Only insert into the local inbox if the device exists on
# this server
- device = row[0]
- message_json = json_encoder.encode(messages_by_device[device])
- messages_json_for_user[device] = message_json
+ device_id = row["device_id"]
+ message_json = json_encoder.encode(messages_by_device[device_id])
+ messages_json_for_user[device_id] = message_json
if messages_json_for_user:
local_by_user_then_device[user_id] = messages_json_for_user
@@ -466,14 +472,17 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
if not local_by_user_then_device:
return
- sql = (
- "INSERT INTO device_inbox"
- " (user_id, device_id, stream_id, message_json)"
- " VALUES (?,?,?,?)"
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_inbox",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "message_json": message_json,
+ }
+ for user_id, messages_by_device in local_by_user_then_device.items()
+ for device_id, message_json in messages_by_device.items()
+ ],
)
- rows = []
- for user_id, messages_by_device in local_by_user_then_device.items():
- for device_id, message_json in messages_by_device.items():
- rows.append((user_id, device_id, stream_id, message_json))
-
- txn.executemany(sql, rows)
|