summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-27 07:41:01 -0400
committerGitHub <noreply@github.com>2020-08-27 07:41:01 -0400
commit30426c7063a7e5567ac21cd10267651ef1935360 (patch)
treed18f000a9576bc52742301ed37793c6563914004 /synapse/storage
parentConvert simple_update* and simple_select* to async (#8173) (diff)
downloadsynapse-30426c7063a7e5567ac21cd10267651ef1935360.tar.xz
Convert additional database methods to async (select list, search, insert_many, delete_*) (#8168)
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py9
-rw-r--r--synapse/storage/database.py99
-rw-r--r--synapse/storage/databases/main/__init__.py12
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py15
-rw-r--r--synapse/storage/databases/main/media_repository.py4
5 files changed, 60 insertions, 79 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 56818f4df8..0db900fa0e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -414,13 +414,14 @@ class BackgroundUpdater(object):
 
         self.register_background_update_handler(update_name, updater)
 
-    def _end_background_update(self, update_name):
+    async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
 
         Args:
-            update_name(str): The name of the completed task to remove
+            update_name:: The name of the completed task to remove
+
         Returns:
-            A deferred that completes once the task is removed.
+            None, completes once the task is removed.
         """
         if update_name != self._current_background_update:
             raise Exception(
@@ -428,7 +429,7 @@ class BackgroundUpdater(object):
                 % update_name
             )
         self._current_background_update = None
-        return self.db_pool.simple_delete_one(
+        await self.db_pool.simple_delete_one(
             "background_updates", keyvalues={"update_name": update_name}
         )
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 38010af600..2f6f49a4bf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -605,7 +605,13 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
+    async def execute(
+        self,
+        desc: str,
+        decoder: Optional[Callable[[Cursor], R]],
+        query: str,
+        *args: Any
+    ) -> R:
         """Runs a single query for a result set.
 
         Args:
@@ -614,7 +620,7 @@ class DatabasePool(object):
             query - The query string to execute
             *args - Query args.
         Returns:
-            Deferred which results to the result of decoder(results)
+            The result of decoder(results)
         """
 
         def interaction(txn):
@@ -624,7 +630,7 @@ class DatabasePool(object):
             else:
                 return txn.fetchall()
 
-        return self.runInteraction(desc, interaction)
+        return await self.runInteraction(desc, interaction)
 
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
@@ -673,15 +679,30 @@ class DatabasePool(object):
 
         txn.execute(sql, vals)
 
-    def simple_insert_many(
+    async def simple_insert_many(
         self, table: str, values: List[Dict[str, Any]], desc: str
-    ) -> defer.Deferred:
-        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        Args:
+            table: string giving the table name
+            values: dict of new column names and values for them
+            desc: string giving a description of the transaction
+        """
+        await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
 
     @staticmethod
     def simple_insert_many_txn(
         txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
     ) -> None:
+        """Executes an INSERT query on the named table.
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            values: dict of new column names and values for them
+            desc: string giving a description of the transaction
+        """
         if not values:
             return
 
@@ -1397,9 +1418,9 @@ class DatabasePool(object):
 
         return dict(zip(retcols, row))
 
-    def simple_delete_one(
+    async def simple_delete_one(
         self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
-    ) -> defer.Deferred:
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
@@ -1407,7 +1428,7 @@ class DatabasePool(object):
             table: string giving the table name
             keyvalues: dict of column names and values to select the row with
         """
-        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+        await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
     def simple_delete_one_txn(
@@ -1446,15 +1467,15 @@ class DatabasePool(object):
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
 
-    def simple_delete_many(
+    async def simple_delete_many(
         self,
         table: str,
         column: str,
         iterable: Iterable[Any],
         keyvalues: Dict[str, Any],
         desc: str,
-    ) -> defer.Deferred:
-        return self.runInteraction(
+    ) -> int:
+        return await self.runInteraction(
             desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
         )
 
@@ -1537,52 +1558,6 @@ class DatabasePool(object):
 
         return cache, min_val
 
-    def simple_select_list_paginate(
-        self,
-        table: str,
-        orderby: str,
-        start: int,
-        limit: int,
-        retcols: Iterable[str],
-        filters: Optional[Dict[str, Any]] = None,
-        keyvalues: Optional[Dict[str, Any]] = None,
-        order_direction: str = "ASC",
-        desc: str = "simple_select_list_paginate",
-    ) -> defer.Deferred:
-        """
-        Executes a SELECT query on the named table with start and limit,
-        of row numbers, which may return zero or number of rows from start to limit,
-        returning the result as a list of dicts.
-
-        Args:
-            table: the table name
-            orderby: Column to order the results by.
-            start: Index to begin the query at.
-            limit: Number of results to return.
-            retcols: the names of the columns to return
-            filters:
-                column names and values to filter the rows with, or None to not
-                apply a WHERE ? LIKE ? clause.
-            keyvalues:
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            order_direction: Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc,
-            self.simple_select_list_paginate_txn,
-            table,
-            orderby,
-            start,
-            limit,
-            retcols,
-            filters=filters,
-            keyvalues=keyvalues,
-            order_direction=order_direction,
-        )
-
     @classmethod
     def simple_select_list_paginate_txn(
         cls,
@@ -1647,14 +1622,14 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    def simple_search_list(
+    async def simple_search_list(
         self,
         table: str,
         term: Optional[str],
         col: str,
         retcols: Iterable[str],
         desc="simple_search_list",
-    ):
+    ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -1665,10 +1640,10 @@ class DatabasePool(object):
             retcols: the names of the columns to return
 
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            A list of dictionaries or None.
         """
 
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_search_list_txn, table, term, col, retcols
         )
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 8b9b6eb472..70cf15dd7f 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
 import calendar
 import logging
 import time
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -559,17 +559,17 @@ class DataStore(
             "get_users_paginate_txn", get_users_paginate_txn
         )
 
-    def search_users(self, term):
+    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
         """Function to search users list for one or more users with
         the matched term.
 
         Args:
-            term (str): search term
-            col (str): column to query term should be matched to
+            term: search term
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries or None.
         """
-        return self.db_pool.simple_search_list(
+        return await self.db_pool.simple_search_list(
             table="users",
             term=term,
             col="name",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 385868bdab..af0b85e2c9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -27,6 +27,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.handlers.e2e_keys import SignatureListItem
+
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
     @trace
@@ -730,14 +733,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 stream_id,
             )
 
-    def store_e2e_cross_signing_signatures(self, user_id, signatures):
+    async def store_e2e_cross_signing_signatures(
+        self, user_id: str, signatures: "Iterable[SignatureListItem]"
+    ) -> None:
         """Stores cross-signing signatures.
 
         Args:
-            user_id (str): the user who made the signatures
-            signatures (iterable[SignatureListItem]): signatures to add
+            user_id: the user who made the signatures
+            signatures: signatures to add
         """
-        return self.db_pool.simple_insert_many(
+        await self.db_pool.simple_insert_many(
             "e2e_cross_signing_signatures",
             [
                 {
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index fc223f5a2a..8361dd63d9 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -314,14 +314,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_remote_media_thumbnail",
         )
 
-    def get_remote_media_before(self, before_ts):
+    async def get_remote_media_before(self, before_ts):
         sql = (
             "SELECT media_origin, media_id, filesystem_id"
             " FROM remote_media_cache"
             " WHERE last_access_ts < ?"
         )
 
-        return self.db_pool.execute(
+        return await self.db_pool.execute(
             "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
         )