diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1b6ccd51c8..c128889bf9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
- txn.database_engine, "k.user_id", user_chunk
- )
- sql = (
- """
- SELECT k.user_id, k.keytype, k.keydata, k.stream_id
- FROM e2e_cross_signing_keys k
- INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
- FROM e2e_cross_signing_keys
- GROUP BY user_id, keytype) s
- USING (user_id, stream_id, keytype)
- WHERE
- """
- + clause
+ txn.database_engine, "user_id", user_chunk
)
+ # Fetch the latest key for each type per user.
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it
+ # encounters, so ordering by stream ID desc will ensure we get
+ # the latest key.
+ sql = """
+ SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ ORDER BY user_id, keytype, stream_id DESC
+ """ % {
+ "clause": clause
+ }
+ else:
+ # SQLite has special handling for bare columns when using
+ # MIN/MAX with a `GROUP BY` clause: the bare column values are
+ # taken from the row that matches the MIN/MAX.
+ sql = """
+ SELECT user_id, keytype, keydata, MAX(stream_id)
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ GROUP BY user_id, keytype
+ """ % {
+ "clause": clause
+ }
+
txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)
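
The Postgres branch relies on `DISTINCT ON` keeping only the first row of each (user_id, keytype) group under the given `ORDER BY`, so ordering by stream_id descending yields the latest key per user and key type. As a sanity check of those semantics, here is a small, self-contained Python sketch (not Synapse code; the rows are made-up illustrative data) that mimics the same "order by stream_id descending, keep the first row per group" selection:

    from itertools import groupby

    rows = [
        # (user_id, keytype, keydata, stream_id) -- made-up illustrative data
        ("@alice:test", "master", "old_master_key", 1),
        ("@alice:test", "master", "new_master_key", 2),
        ("@alice:test", "self_signing", "self_signing_key", 3),
    ]

    # Sort by user_id and keytype ascending, stream_id descending, then keep
    # the first row of each (user_id, keytype) group -- "first row wins",
    # matching what DISTINCT ON does with the ORDER BY above.
    ordered = sorted(rows, key=lambda r: (r[0], r[1], -r[3]))
    latest = [next(group) for _, group in groupby(ordered, key=lambda r: (r[0], r[1]))]

    print(latest)
    # [('@alice:test', 'master', 'new_master_key', 2),
    #  ('@alice:test', 'self_signing', 'self_signing_key', 3)]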
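
The SQLite branch instead leans on SQLite's documented "bare column" behaviour: when a query uses a single MIN()/MAX() aggregate together with GROUP BY, the non-aggregated columns are taken from the row holding the minimum/maximum. A minimal standalone sketch using the stdlib sqlite3 module (again illustrative data only, not Synapse code) shows the query shape from the diff returning the latest key per group:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE e2e_cross_signing_keys "
        "(user_id TEXT, keytype TEXT, keydata TEXT, stream_id INTEGER)"
    )
    conn.executemany(
        "INSERT INTO e2e_cross_signing_keys VALUES (?, ?, ?, ?)",
        [
            ("@alice:test", "master", "old_master_key", 1),
            ("@alice:test", "master", "new_master_key", 2),
            ("@alice:test", "self_signing", "self_signing_key", 3),
        ],
    )
    rows = conn.execute(
        "SELECT user_id, keytype, keydata, MAX(stream_id) "
        "FROM e2e_cross_signing_keys "
        "GROUP BY user_id, keytype"
    ).fetchall()
    # keydata is taken from the row with the highest stream_id in each group,
    # so 'new_master_key' wins over 'old_master_key'.
    print(rows)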