summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-07 12:17:17 -0400
committerGitHub <noreply@github.com>2020-08-07 12:17:17 -0400
commitf3fe6961b211d898aa347771df598c531fbca90c (patch)
tree2389d40f2742f84264b3688edff377774dbfc495
parentClarify that undoing a shutdown might not be possible (#8010) (diff)
downloadsynapse-f3fe6961b211d898aa347771df598c531fbca90c.tar.xz
Convert additional database stores to async/await (#8045)
-rw-r--r--changelog.d/8045.misc1
-rw-r--r--synapse/storage/databases/main/client_ips.py54
-rw-r--r--synapse/storage/databases/main/search.py69
-rw-r--r--synapse/storage/databases/main/signatures.py7
-rw-r--r--synapse/storage/databases/main/user_directory.py124
-rw-r--r--tests/storage/test_user_directory.py4
6 files changed, 107 insertions, 152 deletions
diff --git a/changelog.d/8045.misc b/changelog.d/8045.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8045.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 50d71f5ebc..216a5925fc 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,8 +14,7 @@
 # limitations under the License.
 
 import logging
-
-from twisted.internet import defer
+from typing import Dict, Optional, Tuple
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
@@ -82,21 +81,19 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
         )
 
-    @defer.inlineCallbacks
-    def _remove_user_ip_nonunique(self, progress, batch_size):
+    async def _remove_user_ip_nonunique(self, progress, batch_size):
         def f(conn):
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
 
-        yield self.db_pool.runWithConnection(f)
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.runWithConnection(f)
+        await self.db_pool.updates._end_background_update(
             "user_ips_drop_nonunique_index"
         )
         return 1
 
-    @defer.inlineCallbacks
-    def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress, batch_size):
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
@@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         def user_ips_analyze(txn):
             txn.execute("ANALYZE user_ips")
 
-        yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
+        await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
 
-        yield self.db_pool.updates._end_background_update("user_ips_analyze")
+        await self.db_pool.updates._end_background_update("user_ips_analyze")
 
         return 1
 
-    @defer.inlineCallbacks
-    def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress, batch_size):
         # This works function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
@@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 return None
 
         # Get a last seen that has roughly `batch_size` since `begin_last_seen`
-        end_last_seen = yield self.db_pool.runInteraction(
+        end_last_seen = await self.db_pool.runInteraction(
             "user_ips_dups_get_last_seen", get_last_seen
         )
 
@@ -275,15 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
             )
 
-        yield self.db_pool.runInteraction("user_ips_dups_remove", remove)
+        await self.db_pool.runInteraction("user_ips_dups_remove", remove)
 
         if last:
-            yield self.db_pool.updates._end_background_update("user_ips_remove_dupes")
+            await self.db_pool.updates._end_background_update("user_ips_remove_dupes")
 
         return batch_size
 
-    @defer.inlineCallbacks
-    def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(self, progress, batch_size):
         """Background update to insert last seen info into devices table
         """
 
@@ -346,12 +341,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
             return len(rows)
 
-        updated = yield self.db_pool.runInteraction(
+        updated = await self.db_pool.runInteraction(
             "_devices_last_seen_update", _devices_last_seen_update_txn
         )
 
         if not updated:
-            yield self.db_pool.updates._end_background_update("devices_last_seen")
+            await self.db_pool.updates._end_background_update("devices_last_seen")
 
         return updated
 
@@ -460,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Failed to upsert, log and continue
                 logger.error("Failed to insert client IP %r: %r", entry, e)
 
-    @defer.inlineCallbacks
-    def get_last_client_ip_by_device(self, user_id, device_id):
+    async def get_last_client_ip_by_device(
+        self, user_id: str, device_id: Optional[str]
+    ) -> Dict[Tuple[str, str], dict]:
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
-            user_id (str)
-            device_id (str): If None fetches all devices for the user
+            user_id: The user to fetch devices for.
+            device_id: If None fetches all devices for the user
 
         Returns:
-            defer.Deferred: resolves to a dict, where the keys
-            are (user_id, device_id) tuples. The values are also dicts, with
-            keys giving the column names
+            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+            keys giving the column names from the devices table.
         """
 
         keyvalues = {"user_id": user_id}
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = yield self.db_pool.simple_select_list(
+        res = await self.db_pool.simple_select_list(
             table="devices",
             keyvalues=keyvalues,
             retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -500,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                     }
         return ret
 
-    @defer.inlineCallbacks
-    def get_user_ip_and_agents(self, user):
+    async def get_user_ip_and_agents(self, user):
         user_id = user.to_string()
         results = {}
 
@@ -511,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 user_agent, _, last_seen = self._batch_row_update[key]
                 results[(access_token, ip)] = (user_agent, last_seen)
 
-        rows = yield self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             table="user_ips",
             keyvalues={"user_id": user_id},
             retcols=["access_token", "ip", "user_agent", "last_seen"],
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2162d0712d..7f8d1880e5 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,8 +16,7 @@
 import logging
 import re
 from collections import namedtuple
-
-from twisted.internet import defer
+from typing import List, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -114,8 +113,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
         )
 
-    @defer.inlineCallbacks
-    def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(self, progress, batch_size):
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -206,19 +204,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
             return len(event_search_rows)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_SEARCH_UPDATE_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(self, progress, batch_size):
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
@@ -255,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
-            yield self.db_pool.runWithConnection(create_index)
+            await self.db_pool.runWithConnection(create_index)
 
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.updates._end_background_update(
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
         )
         return 1
 
-    @defer.inlineCallbacks
-    def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -288,12 +284,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 )
                 conn.set_session(autocommit=False)
 
-            yield self.db_pool.runWithConnection(create_index)
+            await self.db_pool.runWithConnection(create_index)
 
             pg = dict(progress)
             pg["have_added_indexes"] = True
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                 self.db_pool.updates._background_update_progress_txn,
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
@@ -331,12 +327,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
             return len(rows), True
 
-        num_rows, finished = yield self.db_pool.runInteraction(
+        num_rows, finished = await self.db_pool.runInteraction(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
         )
 
         if not finished:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME
             )
 
@@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(SearchStore, self).__init__(database, db_conn, hs)
 
-    @defer.inlineCallbacks
-    def search_msgs(self, room_ids, search_term, keys):
+    async def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
 
         Args:
@@ -425,7 +420,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
-        results = yield self.db_pool.execute(
+        results = await self.db_pool.execute(
             "search_msgs", self.db_pool.cursor_to_dict, sql, *args
         )
 
@@ -433,7 +428,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
@@ -442,11 +437,11 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         highlights = None
         if isinstance(self.database_engine, PostgresEngine):
-            highlights = yield self._find_highlights_in_postgres(search_query, events)
+            highlights = await self._find_highlights_in_postgres(search_query, events)
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self.db_pool.execute(
+        count_results = await self.db_pool.execute(
             "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
         )
 
@@ -462,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore):
             "count": count,
         }
 
-    @defer.inlineCallbacks
-    def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
+    async def search_rooms(
+        self,
+        room_ids: List[str],
+        search_term: str,
+        keys: List[str],
+        limit,
+        pagination_token: Optional[str] = None,
+    ) -> List[dict]:
         """Performs a full text search over events with given keys.
 
         Args:
-            room_id (list): The room_ids to search in
-            search_term (str): Search term to search for
-            keys (list): List of keys to search in, currently supports
-                "content.body", "content.name", "content.topic"
-            pagination_token (str): A pagination token previously returned
+            room_ids: The room_ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports "content.body",
+                "content.name", "content.topic"
+            pagination_token: A pagination token previously returned
 
         Returns:
-            list of dicts
+            Each match as a dictionary.
         """
         clauses = []
 
@@ -577,7 +578,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         args.append(limit)
 
-        results = yield self.db_pool.execute(
+        results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
         )
 
@@ -585,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
@@ -594,11 +595,11 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         highlights = None
         if isinstance(self.database_engine, PostgresEngine):
-            highlights = yield self._find_highlights_in_postgres(search_query, events)
+            highlights = await self._find_highlights_in_postgres(search_query, events)
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self.db_pool.execute(
+        count_results = await self.db_pool.execute(
             "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
         )
 
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index dae8e8bd29..be191dd870 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -15,8 +15,6 @@
 
 from unpaddedbase64 import encode_base64
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -40,9 +38,8 @@ class SignatureWorkerStore(SQLBaseStore):
 
         return self.db_pool.runInteraction("get_event_reference_hashes", f)
 
-    @defer.inlineCallbacks
-    def add_event_hashes(self, event_ids):
-        hashes = yield self.get_event_reference_hashes(event_ids)
+    async def add_event_hashes(self, event_ids):
+        hashes = await self.get_event_reference_hashes(event_ids)
         hashes = {
             e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
             for e_id, h in hashes.items()
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d73a8e8ab9..af21fe457a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,8 +16,6 @@
 import logging
 import re
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.state import StateFilter
@@ -59,8 +57,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             "populate_user_directory_cleanup", self._populate_user_directory_cleanup
         )
 
-    @defer.inlineCallbacks
-    def _populate_user_directory_createtables(self, progress, batch_size):
+    async def _populate_user_directory_createtables(self, progress, batch_size):
 
         # Get all the rooms that we want to process.
         def _make_staging_area(txn):
@@ -102,45 +99,43 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
                 self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
 
-        new_pos = yield self.get_max_stream_id_in_current_state_deltas()
-        yield self.db_pool.runInteraction(
+        new_pos = await self.get_max_stream_id_in_current_state_deltas()
+        await self.db_pool.runInteraction(
             "populate_user_directory_temp_build", _make_staging_area
         )
-        yield self.db_pool.simple_insert(
+        await self.db_pool.simple_insert(
             TEMP_TABLE + "_position", {"position": new_pos}
         )
 
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.updates._end_background_update(
             "populate_user_directory_createtables"
         )
         return 1
 
-    @defer.inlineCallbacks
-    def _populate_user_directory_cleanup(self, progress, batch_size):
+    async def _populate_user_directory_cleanup(self, progress, batch_size):
         """
         Update the user directory stream position, then clean up the old tables.
         """
-        position = yield self.db_pool.simple_select_one_onecol(
+        position = await self.db_pool.simple_select_one_onecol(
             TEMP_TABLE + "_position", None, "position"
         )
-        yield self.update_user_directory_stream_pos(position)
+        await self.update_user_directory_stream_pos(position)
 
         def _delete_staging_area(txn):
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "populate_user_directory_cleanup", _delete_staging_area
         )
 
-        yield self.db_pool.updates._end_background_update(
+        await self.db_pool.updates._end_background_update(
             "populate_user_directory_cleanup"
         )
         return 1
 
-    @defer.inlineCallbacks
-    def _populate_user_directory_process_rooms(self, progress, batch_size):
+    async def _populate_user_directory_process_rooms(self, progress, batch_size):
         """
         Args:
             progress (dict)
@@ -151,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         # If we don't have progress filed, delete everything.
         if not progress:
-            yield self.delete_all_from_user_dir()
+            await self.delete_all_from_user_dir()
 
         def _get_next_batch(txn):
             # Only fetch 250 rooms, so we don't fetch too many at once, even
@@ -176,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
             return rooms_to_work_on
 
-        rooms_to_work_on = yield self.db_pool.runInteraction(
+        rooms_to_work_on = await self.db_pool.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 "populate_user_directory_process_rooms"
             )
             return 1
@@ -195,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         processed_event_count = 0
 
         for room_id, event_count in rooms_to_work_on:
-            is_in_room = yield self.is_host_joined(room_id, self.server_name)
+            is_in_room = await self.is_host_joined(room_id, self.server_name)
 
             if is_in_room:
-                is_public = yield self.is_room_world_readable_or_publicly_joinable(
+                is_public = await self.is_room_world_readable_or_publicly_joinable(
                     room_id
                 )
 
-                users_with_profile = yield defer.ensureDeferred(
-                    state.get_current_users_in_room(room_id)
-                )
+                users_with_profile = await state.get_current_users_in_room(room_id)
                 user_ids = set(users_with_profile)
 
                 # Update each user in the user directory.
                 for user_id, profile in users_with_profile.items():
-                    yield self.update_profile_in_user_dir(
+                    await self.update_profile_in_user_dir(
                         user_id, profile.display_name, profile.avatar_url
                     )
 
@@ -223,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                         to_insert.add(user_id)
 
                     if to_insert:
-                        yield self.add_users_in_public_rooms(room_id, to_insert)
+                        await self.add_users_in_public_rooms(room_id, to_insert)
                         to_insert.clear()
                 else:
                     for user_id in user_ids:
@@ -243,22 +236,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                             # If it gets too big, stop and write to the database
                             # to prevent storing too much in RAM.
                             if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
-                                yield self.add_users_who_share_private_room(
+                                await self.add_users_who_share_private_room(
                                     room_id, to_insert
                                 )
                                 to_insert.clear()
 
                     if to_insert:
-                        yield self.add_users_who_share_private_room(room_id, to_insert)
+                        await self.add_users_who_share_private_room(room_id, to_insert)
                         to_insert.clear()
 
             # We've finished a room. Delete it from the table.
-            yield self.db_pool.simple_delete_one(
+            await self.db_pool.simple_delete_one(
                 TEMP_TABLE + "_rooms", {"room_id": room_id}
             )
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "populate_user_directory",
                 self.db_pool.updates._background_update_progress_txn,
                 "populate_user_directory_process_rooms",
@@ -273,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return processed_event_count
 
-    @defer.inlineCallbacks
-    def _populate_user_directory_process_users(self, progress, batch_size):
+    async def _populate_user_directory_process_users(self, progress, batch_size):
         """
         If search_all_users is enabled, add all of the users to the user directory.
         """
         if not self.hs.config.user_directory_search_all_users:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 "populate_user_directory_process_users"
             )
             return 1
@@ -305,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
             return users_to_work_on
 
-        users_to_work_on = yield self.db_pool.runInteraction(
+        users_to_work_on = await self.db_pool.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more users -- complete the transaction.
         if not users_to_work_on:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 "populate_user_directory_process_users"
             )
             return 1
@@ -322,18 +314,18 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
         for user_id in users_to_work_on:
-            profile = yield self.get_profileinfo(get_localpart_from_id(user_id))
-            yield self.update_profile_in_user_dir(
+            profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+            await self.update_profile_in_user_dir(
                 user_id, profile.display_name, profile.avatar_url
             )
 
             # We've finished processing a user. Delete it from the table.
-            yield self.db_pool.simple_delete_one(
+            await self.db_pool.simple_delete_one(
                 TEMP_TABLE + "_users", {"user_id": user_id}
             )
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "populate_user_directory",
                 self.db_pool.updates._background_update_progress_txn,
                 "populate_user_directory_process_users",
@@ -342,8 +334,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return len(users_to_work_on)
 
-    @defer.inlineCallbacks
-    def is_room_world_readable_or_publicly_joinable(self, room_id):
+    async def is_room_world_readable_or_publicly_joinable(self, room_id):
         """Check if the room is either world_readable or publically joinable
         """
 
@@ -353,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             (EventTypes.RoomHistoryVisibility, ""),
         )
 
-        current_state_ids = yield self.get_filtered_current_state_ids(
+        current_state_ids = await self.get_filtered_current_state_ids(
             room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
         if join_rules_id:
-            join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
             if join_rule_ev:
                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
                     return True
 
         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
         if hist_vis_id:
-            hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
             if hist_vis_ev:
                 if hist_vis_ev.content.get("history_visibility") == "world_readable":
                     return True
@@ -590,19 +581,18 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             "remove_from_user_dir", _remove_from_user_dir_txn
         )
 
-    @defer.inlineCallbacks
-    def get_users_in_dir_due_to_room(self, room_id):
+    async def get_users_in_dir_due_to_room(self, room_id):
         """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
-        user_ids_share_pub = yield self.db_pool.simple_select_onecol(
+        user_ids_share_pub = await self.db_pool.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
-        user_ids_share_priv = yield self.db_pool.simple_select_onecol(
+        user_ids_share_priv = await self.db_pool.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"room_id": room_id},
             retcol="other_user_id",
@@ -645,8 +635,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
-    @defer.inlineCallbacks
-    def get_user_dir_rooms_user_is_in(self, user_id):
+    async def get_user_dir_rooms_user_is_in(self, user_id):
         """
         Returns the rooms that a user is in.
 
@@ -656,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         Returns:
             list: user_id
         """
-        rows = yield self.db_pool.simple_select_onecol(
+        rows = await self.db_pool.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
             desc="get_rooms_user_is_in",
         )
 
-        pub_rows = yield self.db_pool.simple_select_onecol(
+        pub_rows = await self.db_pool.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
@@ -674,32 +663,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    @defer.inlineCallbacks
-    def get_rooms_in_common_for_users(self, user_id, other_user_id):
-        """Given two user_ids find out the list of rooms they share.
-        """
-        sql = """
-            SELECT room_id FROM (
-                SELECT c.room_id FROM current_state_events AS c
-                INNER JOIN room_memberships AS m USING (event_id)
-                WHERE type = 'm.room.member'
-                    AND m.membership = 'join'
-                    AND state_key = ?
-            ) AS f1 INNER JOIN (
-                SELECT c.room_id FROM current_state_events AS c
-                INNER JOIN room_memberships AS m USING (event_id)
-                WHERE type = 'm.room.member'
-                    AND m.membership = 'join'
-                    AND state_key = ?
-            ) f2 USING (room_id)
-        """
-
-        rows = yield self.db_pool.execute(
-            "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
-        )
-
-        return [room_id for room_id, in rows]
-
     def get_user_directory_stream_pos(self):
         return self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
@@ -708,8 +671,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             desc="get_user_directory_stream_pos",
         )
 
-    @defer.inlineCallbacks
-    def search_user_dir(self, user_id, search_term, limit):
+    async def search_user_dir(self, user_id, search_term, limit):
         """Searches for users in directory
 
         Returns:
@@ -806,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = yield self.db_pool.execute(
+        results = await self.db_pool.execute(
             "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
         )
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..ecfafe68a9 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     def test_search_user_dir(self):
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
-        r = yield self.store.search_user_dir(ALICE, "bob", 10)
+        r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
         self.assertFalse(r["limited"])
         self.assertEqual(1, len(r["results"]))
         self.assertDictEqual(
@@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     def test_search_user_dir_all_users(self):
         self.hs.config.user_directory_search_all_users = True
         try:
-            r = yield self.store.search_user_dir(ALICE, "bob", 10)
+            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
             self.assertFalse(r["limited"])
             self.assertEqual(2, len(r["results"]))
             self.assertDictEqual(