summary refs log tree commit diff
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2021-12-13 17:34:26 +0000
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2021-12-13 17:41:27 +0000
commitb17f575d4212f8c7aaeb75195e67f4fa63d68e71 (patch)
treea1474193da7f3f7583ab1be53fd9bf1106d25327
parentFix up tests that weren't expecting extra call arguments (diff)
downloadsynapse-b17f575d4212f8c7aaeb75195e67f4fa63d68e71.tar.xz
Count the OTKs in bulk
-rw-r--r--synapse/appservice/scheduler.py3
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py52
2 files changed, 52 insertions, 3 deletions
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 686d98e791..0d11297f2b 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -280,8 +280,9 @@ class _ServiceQueuer:
         Given a list of application service users that are interesting,
         compute one-time key counts and fallback key usages for the users.
         """
+        otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
         # OSTD implement me!
-        return {}, {}
+        return otk_counts, {}
 
 
 class _TransactionController:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..9b1c0f12d4 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -22,9 +22,14 @@ from canonicaljson import encode_canonical_json
 from twisted.enterprise.adbapi import Connection
 
 from synapse.api.constants import DeviceKeyAlgorithms
+from synapse.appservice import TransactionOneTimeKeyCounts
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
@@ -397,6 +402,49 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def count_bulk_e2e_one_time_keys_for_as(
+        self, user_ids: Collection[str]
+    ) -> TransactionOneTimeKeyCounts:
+        """
+        Counts, in bulk, the one-time keys for all the users specified.
+        Intended to be used by application services for populating OTK counts in
+        transactions.
+
+        Return structure is of the shape:
+          user_id -> device_id -> algorithm -> count
+        """
+
+        def _count_bulk_e2e_one_time_keys_txn(
+            txn: LoggingTransaction,
+        ) -> TransactionOneTimeKeyCounts:
+            user_in_where_clause, user_parameters = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids
+            )
+            sql = f"""
+                SELECT user_id, device_id, algorithm, COUNT(key_id)
+                FROM devices
+                LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
+                WHERE {user_in_where_clause}
+                GROUP BY user_id, device_id, algorithm
+            """
+            txn.execute(sql, user_parameters)
+
+            result = {}
+
+            for user_id, device_id, algorithm, count in txn:
+                device_count_by_algo = result.setdefault(user_id, {}).setdefault(
+                    device_id, {}
+                )
+                if algorithm is not None:
+                    # algorithm will be None if this device has no keys.
+                    device_count_by_algo[algorithm] = count
+
+            return result
+
+        return await self.db_pool.runInteraction(
+            "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
+        )
+
     async def set_e2e_fallback_keys(
         self, user_id: str, device_id: str, fallback_keys: JsonDict
     ) -> None: