diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7ef..d0b9742695 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -391,26 +393,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
result = {}
- batch_size = 100
- chunks = [
- user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
- ]
- for user_chunk in chunks:
- sql = """
+ for user_chunk in batch_iter(user_ids, 100):
+ clause, params = make_in_list_sql_clause(
+ txn.database_engine, "k.user_id", user_chunk
+ )
+ sql = (
+ """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
- WHERE k.user_id IN (%s)
- """ % (
- ",".join("?" for u in user_chunk),
+ WHERE
+ """
+ + clause
)
- query_params = []
- query_params.extend(user_chunk)
- txn.execute(sql, query_params)
+ txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
@@ -453,15 +453,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k
devices[(user_id, device_id)] = key_type
- device_list = list(devices)
-
- # split into batches
- batch_size = 100
- chunks = [
- device_list[i : i + batch_size]
- for i in range(0, len(device_list), batch_size)
- ]
- for user_chunk in chunks:
+ for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
@@ -469,11 +461,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s)
""" % (
" OR ".join(
- "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ "(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
- for item in devices:
+ for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
|