summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-27 13:38:41 -0400
committerGitHub <noreply@github.com>2020-08-27 13:38:41 -0400
commit9b7ac03af3e7ceae7d1933db566ee407cfdef72d (patch)
treebd29b6da47cb08b846e05ce004f0e8d4008ed374
parentsimple_search_list_txn should return None, not 0. (#8187) (diff)
downloadsynapse-9b7ac03af3e7ceae7d1933db566ee407cfdef72d.tar.xz
Convert calls of async database methods to async (#8166)
-rw-r--r--changelog.d/8166.misc1
-rw-r--r--synapse/federation/persistence.py16
-rw-r--r--synapse/federation/units.py4
-rw-r--r--synapse/storage/databases/main/appservice.py6
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/group_server.py30
-rw-r--r--synapse/storage/databases/main/keys.py26
-rw-r--r--synapse/storage/databases/main/media_repository.py22
-rw-r--r--synapse/storage/databases/main/openid.py6
-rw-r--r--synapse/storage/databases/main/profile.py10
-rw-r--r--synapse/storage/databases/main/registration.py29
-rw-r--r--synapse/storage/databases/main/room.py16
-rw-r--r--synapse/storage/databases/main/stats.py10
-rw-r--r--synapse/storage/databases/main/transactions.py18
14 files changed, 114 insertions, 84 deletions
diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8166.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index d68b4bd670..769cd5de28 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module.
 
 import logging
 
+from synapse.federation.units import Transaction
 from synapse.logging.utils import log_function
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -49,15 +51,15 @@ class TransactionActions(object):
         return self.store.get_received_txn_response(transaction.transaction_id, origin)
 
     @log_function
-    def set_response(self, origin, transaction, code, response):
+    async def set_response(
+        self, origin: str, transaction: Transaction, code: int, response: JsonDict
+    ) -> None:
         """ Persist how we responded to a transaction.
-
-        Returns:
-            Deferred
         """
-        if not transaction.transaction_id:
+        transaction_id = transaction.transaction_id  # type: ignore
+        if not transaction_id:
             raise RuntimeError("Cannot persist a transaction with no transaction_id")
 
-        return self.store.set_received_txn_response(
-            transaction.transaction_id, origin, code, response
+        await self.store.set_received_txn_response(
+            transaction_id, origin, code, response
         )
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 6b32e0dcbf..64d98fc8f6 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject):
         if "edus" in kwargs and not kwargs["edus"]:
             del kwargs["edus"]
 
-        super(Transaction, self).__init__(
-            transaction_id=transaction_id, pdus=pdus, **kwargs
-        )
+        super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
 
     @staticmethod
     def create_new(pdus, **kwargs):
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 77723f7d4d..92f56f1602 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore(
             return result.get("state")
         return None
 
-    def set_appservice_state(self, service, state):
+    async def set_appservice_state(self, service, state) -> None:
         """Set the application service state.
 
         Args:
             service(ApplicationService): The service whose state to set.
             state(ApplicationServiceState): The connectivity state to apply.
-        Returns:
-            An Awaitable which resolves when the state was set successfully.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
         )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a811a39eb5..ecd3f3b310 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return {row["user_id"] for row in rows}
 
-    def mark_remote_user_device_cache_as_stale(self, user_id: str):
+    async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
         """Records that the server has reason to believe the cache of the devices
         for the remote users is out of date.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="device_lists_remote_resync",
             keyvalues={"user_id": user_id},
             values={},
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e3ead71853..8acf254bf3 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_room_from_summary",
         )
 
-    def upsert_group_category(self, group_id, category_id, profile, is_public):
+    async def upsert_group_category(
+        self,
+        group_id: str,
+        category_id: str,
+        profile: Optional[JsonDict],
+        is_public: Optional[bool],
+    ) -> None:
         """Add/update room category for group
         """
         insertion_values = {}
@@ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore):
         else:
             update_values["is_public"] = is_public
 
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             values=update_values,
@@ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_group_category",
         )
 
-    def upsert_group_role(self, group_id, role_id, profile, is_public):
+    async def upsert_group_role(
+        self,
+        group_id: str,
+        role_id: str,
+        profile: Optional[JsonDict],
+        is_public: Optional[bool],
+    ) -> None:
         """Add/remove user role
         """
         insertion_values = {}
@@ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore):
         else:
             update_values["is_public"] = is_public
 
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             values=update_values,
@@ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_user_from_summary",
         )
 
-    def add_group_invite(self, group_id, user_id):
+    async def add_group_invite(self, group_id: str, user_id: str) -> None:
         """Record that the group server has invited a user
         """
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
             desc="add_group_invite",
@@ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore):
             "remove_user_from_group", _remove_user_from_group_txn
         )
 
-    def add_room_to_group(self, group_id, room_id, is_public):
-        return self.db_pool.simple_insert(
+    async def add_room_to_group(
+        self, group_id: str, room_id: str, is_public: bool
+    ) -> None:
+        await self.db_pool.simple_insert(
             table="group_rooms",
             values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
             desc="add_room_to_group",
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index fadcad51e7..1c0a049c55 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -140,22 +140,28 @@ class KeyStore(SQLBaseStore):
         for i in invalidations:
             invalidate((i,))
 
-    def store_server_keys_json(
-        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
-    ):
+    async def store_server_keys_json(
+        self,
+        server_name: str,
+        key_id: str,
+        from_server: str,
+        ts_now_ms: int,
+        ts_expires_ms: int,
+        key_json_bytes: bytes,
+    ) -> None:
         """Stores the JSON bytes for a set of keys from a server
         The JSON should be signed by the originating server, the intermediate
         server, and by this server. Updates the value for the
         (server_name, key_id, from_server) triplet if one already existed.
         Args:
-            server_name (str): The name of the server.
-            key_id (str): The identifer of the key this JSON is for.
-            from_server (str): The server this JSON was fetched from.
-            ts_now_ms (int): The time now in milliseconds.
-            ts_valid_until_ms (int): The time when this json stops being valid.
-            key_json (bytes): The encoded JSON.
+            server_name: The name of the server.
+            key_id: The identifer of the key this JSON is for.
+            from_server: The server this JSON was fetched from.
+            ts_now_ms: The time now in milliseconds.
+            ts_valid_until_ms: The time when this json stops being valid.
+            key_json_bytes: The encoded JSON.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="server_keys_json",
             keyvalues={
                 "server_name": server_name,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8361dd63d9..3919ecad69 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_local_media",
         )
 
-    def store_local_media(
+    async def store_local_media(
         self,
         media_id,
         media_type,
@@ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         media_length,
         user_id,
         url_cache=None,
-    ):
-        return self.db_pool.simple_insert(
+    ) -> None:
+        await self.db_pool.simple_insert(
             "local_media_repository",
             {
                 "media_id": media_id,
@@ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
         return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
-    def store_url_cache(
+    async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
             {
                 "url": url,
@@ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_local_media_thumbnails",
         )
 
-    def store_local_thumbnail(
+    async def store_local_thumbnail(
         self,
         media_id,
         thumbnail_width,
@@ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "local_media_repository_thumbnails",
             {
                 "media_id": media_id,
@@ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_cached_remote_media",
         )
 
-    def store_cached_remote_media(
+    async def store_cached_remote_media(
         self,
         origin,
         media_id,
@@ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         upload_name,
         filesystem_id,
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "remote_media_cache",
             {
                 "media_origin": origin,
@@ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_remote_media_thumbnails",
         )
 
-    def store_remote_media_thumbnail(
+    async def store_remote_media_thumbnail(
         self,
         origin,
         media_id,
@@ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "remote_media_cache_thumbnails",
             {
                 "media_origin": origin,
diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py
index dcd1ff911a..4db8949da7 100644
--- a/synapse/storage/databases/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -2,8 +2,10 @@ from synapse.storage._base import SQLBaseStore
 
 
 class OpenIdStore(SQLBaseStore):
-    def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
-        return self.db_pool.simple_insert(
+    async def insert_open_id_token(
+        self, token: str, ts_valid_until_ms: int, user_id: str
+    ) -> None:
+        await self.db_pool.simple_insert(
             table="open_id_tokens",
             values={
                 "token": token,
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 858fd92420..301875a672 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -66,8 +66,8 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_from_remote_profile_cache",
         )
 
-    def create_profile(self, user_localpart):
-        return self.db_pool.simple_insert(
+    async def create_profile(self, user_localpart: str) -> None:
+        await self.db_pool.simple_insert(
             table="profiles", values={"user_id": user_localpart}, desc="create_profile"
         )
 
@@ -93,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore):
 
 
 class ProfileStore(ProfileWorkerStore):
-    def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+    async def add_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> None:
         """Ensure we are caching the remote user's profiles.
 
         This should only be called when `is_subscribed_remote_profile_for_user`
         would return true for the user.
         """
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 48bda66f3e..28f7ae0430 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
 
 import logging
 import re
-from typing import Any, Awaitable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -549,23 +549,22 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="user_delete_threepids",
         )
 
-    def add_user_bound_threepid(self, user_id, medium, address, id_server):
+    async def add_user_bound_threepid(
+        self, user_id: str, medium: str, address: str, id_server: str
+    ):
         """The server proxied a bind request to the given identity server on
         behalf of the given user. We need to remember this in case the user
         asks us to unbind the threepid.
 
         Args:
-            user_id (str)
-            medium (str)
-            address (str)
-            id_server (str)
-
-        Returns:
-            Awaitable
+            user_id
+            medium
+            address
+            id_server
         """
         # We need to use an upsert, in case they user had already bound the
         # threepid
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -1083,9 +1082,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    def record_user_external_id(
+    async def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
-    ) -> Awaitable:
+    ) -> None:
         """Record a mapping from an external user id to a mxid
 
         Args:
@@ -1093,7 +1092,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="user_external_ids",
             values={
                 "auth_provider": auth_provider,
@@ -1237,12 +1236,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         return res if res else False
 
-    def add_user_pending_deactivation(self, user_id):
+    async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
         in
         """
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             "users_pending_deactivation",
             values={"user_id": user_id},
             desc="add_user_pending_deactivation",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 66d7135413..a92641c339 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchStore
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -1296,11 +1296,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
         return self.db_pool.runInteraction("get_rooms", f)
 
-    def add_event_report(
-        self, room_id, event_id, user_id, reason, content, received_ts
-    ):
+    async def add_event_report(
+        self,
+        room_id: str,
+        event_id: str,
+        user_id: str,
+        reason: str,
+        content: JsonDict,
+        received_ts: int,
+    ) -> None:
         next_id = self._event_reports_id_gen.get_next()
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="event_reports",
             values={
                 "id": next_id,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9fe97af56a..7af2608ca4 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 
 import logging
 from itertools import chain
-from typing import Tuple
+from typing import Any, Dict, Tuple
 
 from twisted.internet.defer import DeferredLock
 
@@ -222,11 +222,11 @@ class StatsStore(StateDeltasStore):
             desc="stats_incremental_position",
         )
 
-    def update_room_state(self, room_id, fields):
+    async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
         """
         Args:
-            room_id (str)
-            fields (dict[str:Any])
+            room_id
+            fields
         """
 
         # For whatever reason some of the fields may contain null bytes, which
@@ -244,7 +244,7 @@ class StatsStore(StateDeltasStore):
             if field and "\0" in field:
                 fields[col] = None
 
-        return self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             table="room_stats_state",
             keyvalues={"room_id": room_id},
             values=fields,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 52668dbdf9..2efcc0dc66 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
+from synapse.types import JsonDict
 from synapse.util.caches.expiringcache import ExpiringCache
 
 db_binary_type = memoryview
@@ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore):
         else:
             return None
 
-    def set_received_txn_response(self, transaction_id, origin, code, response_dict):
-        """Persist the response we returened for an incoming transaction, and
+    async def set_received_txn_response(
+        self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
+    ) -> None:
+        """Persist the response we returned for an incoming transaction, and
         should return for subsequent transactions with the same transaction_id
         and origin.
 
         Args:
-            txn
-            transaction_id (str)
-            origin (str)
-            code (int)
-            response_json (str)
+            transaction_id: The incoming transaction ID.
+            origin: The origin server.
+            code: The response code.
+            response_dict: The response, to be encoded into JSON.
         """
 
-        return self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             table="received_transactions",
             values={
                 "transaction_id": transaction_id,