summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/15215.misc1
-rw-r--r--synapse/storage/database.py10
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py24
3 files changed, 26 insertions, 9 deletions
diff --git a/changelog.d/15215.misc b/changelog.d/15215.misc
new file mode 100644
index 0000000000..fe52a56a7e
--- /dev/null
+++ b/changelog.d/15215.misc
@@ -0,0 +1 @@
+Refactor database transaction for query users' devices to reduce database pool contention.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index feaa6cdd07..5efe31aa19 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -672,7 +672,15 @@ class DatabasePool:
             f = cast(types.FunctionType, func)  # type: ignore[redundant-cast]
             if f.__closure__:
                 for i, cell in enumerate(f.__closure__):
-                    if inspect.isgenerator(cell.cell_contents):
+                    try:
+                        contents = cell.cell_contents
+                    except ValueError:
+                        # cell.cell_contents can raise if the "cell" is empty,
+                        # which indicates that the variable is currently
+                        # unbound.
+                        continue
+
+                    if inspect.isgenerator(contents):
                         logger.error(
                             "Programming error: function %s references generator %s "
                             "via its closure",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b9c39b1718..a3b6c8ae8e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -244,9 +244,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
-        result = await self.db_pool.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
+        result = await self._get_e2e_device_keys(
             query_list,
             include_all_devices,
             include_deleted_devices,
@@ -285,9 +283,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         log_kv(result)
         return result
 
-    def _get_e2e_device_keys_txn(
+    async def _get_e2e_device_keys(
         self,
-        txn: LoggingTransaction,
         query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
@@ -319,7 +316,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         if user_list:
             user_id_in_list_clause, user_args = make_in_list_sql_clause(
-                txn.database_engine, "user_id", user_list
+                self.database_engine, "user_id", user_list
             )
             query_clauses.append(user_id_in_list_clause)
             query_params_list.append(user_args)
@@ -332,13 +329,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     user_device_id_in_list_clause,
                     user_device_args,
                 ) = make_tuple_in_list_sql_clause(
-                    txn.database_engine, ("user_id", "device_id"), user_device_batch
+                    self.database_engine, ("user_id", "device_id"), user_device_batch
                 )
                 query_clauses.append(user_device_id_in_list_clause)
                 query_params_list.append(user_device_args)
 
         result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
-        for query_clause, query_params in zip(query_clauses, query_params_list):
+
+        def get_e2e_device_keys_txn(
+            txn: LoggingTransaction, query_clause: str, query_params: list
+        ) -> None:
             sql = (
                 "SELECT user_id, device_id, "
                 "    d.display_name, "
@@ -361,6 +361,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     display_name, db_to_json(key_json) if key_json else None
                 )
 
+        for query_clause, query_params in zip(query_clauses, query_params_list):
+            await self.db_pool.runInteraction(
+                "_get_e2e_device_keys",
+                get_e2e_device_keys_txn,
+                query_clause,
+                query_params,
+            )
+
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
                 if device_id is None: