summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8133.misc1
-rw-r--r--synapse/crypto/keyring.py7
-rw-r--r--synapse/module_api/__init__.py10
-rw-r--r--synapse/spam_checker_api/__init__.py6
-rw-r--r--synapse/storage/databases/main/group_server.py7
-rw-r--r--synapse/storage/databases/main/keys.py28
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py13
7 files changed, 41 insertions, 31 deletions
diff --git a/changelog.d/8133.misc b/changelog.d/8133.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8133.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 28ef7cfdb9..81c4b430b2 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        return await yieldable_gather_results(
-            get_key, keys_to_fetch.items()
-        ).addCallback(lambda _: results)
+        await yieldable_gather_results(get_key, keys_to_fetch.items())
+        return results
 
     async def get_server_verify_key_v2_direct(self, server_name, key_ids):
         """
@@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             key_ids (iterable[str]):
 
         Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+            dict[str, FetchKeyResult]: map from key ID to lookup result
 
         Raises:
             KeyLookupError if there was a problem making the lookup
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index c2fb757d9a..ae0e359a77 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -167,8 +167,10 @@ class ModuleApi(object):
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
-        return self._store.record_user_external_id(
-            auth_provider_id, remote_user_id, registered_user_id
+        return defer.ensureDeferred(
+            self._store.record_user_external_id(
+                auth_provider_id, remote_user_id, registered_user_id
+            )
         )
 
     def generate_short_term_login_token(
@@ -223,7 +225,9 @@ class ModuleApi(object):
         Returns:
             Deferred[object]: result of func
         """
-        return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+        return defer.ensureDeferred(
+            self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
+        )
 
     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 4d9b13ac04..7f63f1bfa0 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -48,8 +48,10 @@ class SpamCheckerApi(object):
             twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
                 The filtered state events in the room.
         """
-        state_ids = yield self._store.get_filtered_current_state_ids(
-            room_id=room_id, state_filter=StateFilter.from_types(types)
+        state_ids = yield defer.ensureDeferred(
+            self._store.get_filtered_current_state_ids(
+                room_id=room_id, state_filter=StateFilter.from_types(types)
+            )
         )
         state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
         return state.values()
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..0e3b8739c6 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_users_for_summary_by_role", _get_users_for_summary_txn
         )
 
-    def is_user_in_group(self, user_id, group_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+        result = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
             allow_none=True,
             desc="is_user_in_group",
-        ).addCallback(lambda r: bool(r))
+        )
+        return bool(result)
 
     def is_user_admin_in_group(self, group_id, user_id):
         return self.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
 
 import itertools
 import logging
+from typing import Iterable, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
 
         return self.db_pool.runInteraction("get_server_verify_keys", _txn)
 
-    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+    async def store_server_verify_keys(
+        self,
+        from_server: str,
+        ts_added_ms: int,
+        verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+    ) -> None:
         """Stores NACL verification keys for remote servers.
         Args:
-            from_server (str): Where the verification keys were looked up
-            ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+            from_server: Where the verification keys were looked up
+            ts_added_ms: The time to record that the key was added
+            verify_keys:
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
@@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
             # param, which is itself the 2-tuple (server_name, key_id).
             invalidations.append((server_name, key_id))
 
-        def _invalidate(res):
-            f = self._get_server_verify_key.invalidate
-            for i in invalidations:
-                f((i,))
-            return res
-
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "store_server_verify_keys",
             self.db_pool.simple_upsert_many_txn,
             table="server_signature_keys",
@@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
                 "verify_key",
             ),
             value_values=value_values,
-        ).addCallback(_invalidate)
+        )
+
+        invalidate = self._get_server_verify_key.invalidate
+        for i in invalidations:
+            invalidate((i,))
 
     def store_server_keys_json(
         self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index da23fe7355..e3547e53b3 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,30 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import operator
-
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
 
 class UserErasureWorkerStore(SQLBaseStore):
     @cached()
-    def is_user_erased(self, user_id):
+    async def is_user_erased(self, user_id: str) -> bool:
         """
         Check if the given user id has requested erasure
 
         Args:
-            user_id (str): full user id to check
+            user_id: full user id to check
 
         Returns:
-            Deferred[bool]: True if the user has requested erasure
+            True if the user has requested erasure
         """
-        return self.db_pool.simple_select_onecol(
+        result = await self.db_pool.simple_select_onecol(
             table="erased_users",
             keyvalues={"user_id": user_id},
             retcol="1",
             desc="is_user_erased",
-        ).addCallback(operator.truth)
+        )
+        return bool(result)
 
     @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
     async def are_users_erased(self, user_ids):