summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8199.misc1
-rw-r--r--synapse/storage/databases/main/__init__.py50
-rw-r--r--synapse/storage/databases/main/devices.py38
-rw-r--r--synapse/storage/databases/main/events_worker.py48
-rw-r--r--synapse/storage/databases/main/purge_events.py30
-rw-r--r--synapse/storage/databases/main/receipts.py14
-rw-r--r--synapse/storage/databases/main/relations.py103
7 files changed, 147 insertions, 137 deletions
diff --git a/changelog.d/8199.misc b/changelog.d/8199.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8199.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index e6536c8456..99890ffbf3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@
 import calendar
 import logging
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -294,16 +294,16 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    def count_daily_users(self):
+    async def count_daily_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_users", self._count_users, yesterday
         )
 
-    def count_monthly_users(self):
+    async def count_monthly_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 30 days.
         Note this method is intended for phonehome metrics only and is different
@@ -311,7 +311,7 @@ class DataStore(
         amongst other things, includes a 3 day grace period before a user counts.
         """
         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
@@ -330,15 +330,15 @@ class DataStore(
         (count,) = txn.fetchone()
         return count
 
-    def count_r30_users(self):
+    async def count_r30_users(self) -> Dict[str, int]:
         """
         Counts the number of 30 day retained users, defined as:-
          * Users who have created their accounts more than 30 days ago
          * Where last seen at most 30 days ago
          * Where account creation and last_seen are > 30 days apart
 
-         Returns counts globaly for a given user as well as breaking
-         by platform
+        Returns:
+             A mapping of counts globally as well as broken out by platform.
         """
 
         def _count_r30_users(txn):
@@ -411,7 +411,7 @@ class DataStore(
 
             return results
 
-        return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
 
     def _get_start_of_day(self):
         """
@@ -421,7 +421,7 @@ class DataStore(
         today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
         return today_start * 1000
 
-    def generate_user_daily_visits(self):
+    async def generate_user_daily_visits(self) -> None:
         """
         Generates daily visit data for use in cohort/ retention analysis
         """
@@ -476,7 +476,7 @@ class DataStore(
             # frequently
             self._last_user_visit_update = now
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
 
@@ -500,22 +500,28 @@ class DataStore(
             desc="get_users",
         )
 
-    def get_users_paginate(
-        self, start, limit, user_id=None, name=None, guests=True, deactivated=False
-    ):
+    async def get_users_paginate(
+        self,
+        start: int,
+        limit: int,
+        user_id: Optional[str] = None,
+        name: Optional[str] = None,
+        guests: bool = True,
+        deactivated: bool = False,
+    ) -> Tuple[List[Dict[str, Any]], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
 
         Args:
-            start (int): start number to begin the query from
-            limit (int): number of rows to retrieve
-            user_id (string): search for user_id. ignored if name is not None
-            name (string): search for local part of user_id or display name
-            guests (bool): whether to in include guest users
-            deactivated (bool): whether to include deactivated users
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            user_id: search for user_id. ignored if name is not None
+            name: search for local part of user_id or display name
+            guests: whether to in include guest users
+            deactivated: whether to include deactivated users
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]], int
+            A tuple of a list of mappings from user to information and a count of total users.
         """
 
         def get_users_paginate_txn(txn):
@@ -558,7 +564,7 @@ class DataStore(
             users = self.db_pool.cursor_to_dict(txn)
             return users, count
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_paginate_txn", get_users_paginate_txn
         )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8379c73c4..a29157d979 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -313,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return results
 
-    def _get_last_device_update_for_remote_user(
+    async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
-    ):
+    ) -> int:
         def f(txn):
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
@@ -326,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+        return await self.db_pool.runInteraction(
+            "get_last_device_update_for_remote_user", f
+        )
 
-    def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+    async def mark_as_sent_devices_by_remote(
+        self, destination: str, stream_id: int
+    ) -> None:
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -684,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )
 
-    def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+    async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user.
         """
 
@@ -698,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
@@ -959,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             desc="update_device",
         )
 
-    def update_remote_device_list_cache_entry(
+    async def update_remote_device_list_cache_entry(
         self, user_id: str, device_id: str, content: JsonDict, stream_id: int
-    ):
+    ) -> None:
         """Updates a single device in the cache of a remote user's devicelist.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -972,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: ID of decivice being updated
             content: new data on this device
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -1028,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             lock=False,
         )
 
-    def update_remote_device_list_cache(
+    async def update_remote_device_list_cache(
         self, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         """Replace the entire cache of the remote user's devices.
 
         Note: assumes that we are the only thread that can be updating this user's
@@ -1040,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: User to update device list for
             devices: list of device objects supplied over federation
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -1054,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
     def _update_remote_device_list_cache_txn(
         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e6247d682d..a7a73cc3d8 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_dict
 
-    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+    def _maybe_redact_event_row(
+        self,
+        original_ev: EventBase,
+        redactions: Iterable[str],
+        event_map: Dict[str, EventBase],
+    ) -> Optional[EventBase]:
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.
 
         Args:
-             original_ev (EventBase):
-             redactions (iterable[str]): list of event ids of potential redaction events
-             event_map (dict[str, EventBase]): other events which have been fetched, in
-                 which we can look up the redaaction events. Map from event id to event.
+             original_ev: The original event.
+             redactions: list of event ids of potential redaction events
+             event_map: other events which have been fetched, in which we can
+                look up the redaaction events. Map from event id to event.
 
         Returns:
-            Deferred[EventBase|None]: if the event should be redacted, a pruned
-                event object. Otherwise, None.
+            If the event should be redacted, a pruned event object. Otherwise, None.
         """
         if original_ev.type == "m.room.create":
             # we choose to ignore redactions of m.room.create events.
@@ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore):
         row = txn.fetchone()
         return row[0] if row else 0
 
-    def get_current_state_event_counts(self, room_id):
+    async def get_current_state_event_counts(self, room_id: str) -> int:
         """
         Gets the current number of state events in a room.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            Deferred[int]
+            The current number of state events.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
@@ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore):
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+    async def get_all_new_forward_event_rows(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore):
             current_id: the maximum stream_id to return up to
             limit: the maximum number of rows to return
 
-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
 
-    def get_ex_outlier_stream_rows(self, last_id, current_id):
+    async def get_ex_outlier_stream_rows(
+        self, last_id: int, current_id: int
+    ) -> List[Tuple]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
             last_id: the last stream_id from the previous batch.
             current_id: the maximum stream_id to return up to
 
-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id))
             return txn.fetchall()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
         )
 
@@ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
-    def get_next_event_to_expire(self):
+    async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry
         table, or None if there's no more event to expire.
 
-        Returns: Deferred[Optional[Tuple[str, int]]]
+        Returns:
             A tuple containing the event ID as its first element and an expiry timestamp
             as its second one, if there's at least one row in the event_expiry table.
             None otherwise.
@@ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore):
 
             return txn.fetchone()
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3526b6fd66..ea833829ae 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
 
 
 class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
-    def purge_history(self, room_id, token, delete_local_events):
+    async def purge_history(
+        self, room_id: str, token: str, delete_local_events: bool
+    ) -> Set[int]:
         """Deletes room history before a certain point
 
         Args:
-            room_id (str):
-
-            token (str): A topological token to delete events before
-
-            delete_local_events (bool):
+            room_id:
+            token: A topological token to delete events before
+            delete_local_events:
                 if True, we will delete local events as well as remote ones
                 (instead of just marking them as outliers and deleting their
                 state groups).
 
         Returns:
-            Deferred[set[int]]: The set of state groups that are referenced by
-            deleted events.
+            The set of state groups that are referenced by deleted events.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
 
         return referenced_state_groups
 
-    def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> List[int]:
         """Deletes all record of a room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            Deferred[List[int]]: The list of state groups to delete.
+            The list of state groups to delete.
         """
-
-        return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+        return await self.db_pool.runInteraction(
+            "purge_room", self._purge_room_txn, room_id
+        )
 
     def _purge_room_txn(self, txn, room_id):
         # First we fetch all the state groups that should be deleted, before
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 436f22ad2d..4a0d5a320e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         return results
 
-    def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+    async def get_users_sent_receipts_between(
+        self, last_id: int, current_id: int
+    ) -> List[str]:
         """Get all users who sent receipts between `last_id` exclusive and
         `current_id` inclusive.
 
         Returns:
-            Deferred[List[str]]
+            The list of users.
         """
 
         if last_id == current_id:
@@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return [r[0] for r in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
         )
 
@@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         return stream_id, max_persisted_id
 
-    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.db_pool.runInteraction(
+    async def insert_graph_receipt(
+        self, room_id, receipt_type, user_id, event_ids, data
+    ):
+        return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
 
 class RelationsWorkerStore(SQLBaseStore):
     @cached(tree=True)
-    def get_relations_for_event(
+    async def get_relations_for_event(
         self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[RelationPaginationToken] = None,
+        to_token: Optional[RelationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
+            event_id: Fetch events that relate to this event ID.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of event IDs that match relations
-            requested. The rows are of the form `{"event_id": "..."}`.
+            List of event IDs that match relations requested. The rows are of
+            the form `{"event_id": "..."}`.
         """
 
         where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
     @cached(tree=True)
-    def get_aggregation_groups_for_event(
+    async def get_aggregation_groups_for_event(
         self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        event_type: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[AggregationPaginationToken] = None,
+        to_token: Optional[AggregationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
 
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
         on an event.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
+            event_id: Fetch events that relate to this event ID.
+            event_type: Only fetch events with this event type, if given.
+            limit: Only fetch the `limit` groups.
+            direction: Whether to fetch the highest count first (`"b"`) or
                 the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
-
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-            match. Each row is a dict with `type`, `key` and `count` fields.
+            List of groups of annotations that match. Each row is a dict with
+            `type`, `key` and `count` fields.
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.get_event(edit_id, allow_none=True)
 
-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+    async def has_user_annotated_event(
+        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+    ) -> bool:
         """Check if a user has already annotated an event with the same key
         (e.g. already liked an event).
 
         Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
+            parent_id: The event being annotated
+            event_type: The event type of the annotation
+            aggregation_key: The aggregation key of the annotation
+            sender: The sender of the annotation
 
         Returns:
-            Deferred[bool]
+            True if the event is already annotated.
         """
 
         sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )