summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-09-01 09:21:48 -0400
committerGitHub <noreply@github.com>2020-09-01 09:21:48 -0400
commit54f8d73c005cf0401d05fc90e857da253f9d1168 (patch)
tree438e00a60e97c57d150d42c8d884d5ef0fe0f1da /synapse/storage/databases/main/relations.py
parentConvert the well known resolver to async (#8214) (diff)
downloadsynapse-54f8d73c005cf0401d05fc90e857da253f9d1168.tar.xz
Convert additional databases to async/await (#8199)
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py103
1 files changed, 48 insertions, 55 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a9ceffc20e..5cd61547f7 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
 
 class RelationsWorkerStore(SQLBaseStore):
     @cached(tree=True)
-    def get_relations_for_event(
+    async def get_relations_for_event(
         self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[RelationPaginationToken] = None,
+        to_token: Optional[RelationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
+            event_id: Fetch events that relate to this event ID.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of event IDs that match relations
-            requested. The rows are of the form `{"event_id": "..."}`.
+            List of event IDs that match relations requested. The rows are of
+            the form `{"event_id": "..."}`.
         """
 
         where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
     @cached(tree=True)
-    def get_aggregation_groups_for_event(
+    async def get_aggregation_groups_for_event(
         self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        event_type: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[AggregationPaginationToken] = None,
+        to_token: Optional[AggregationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
 
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
         on an event.
 
         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
+            event_id: Fetch events that relate to this event ID.
+            event_type: Only fetch events with this event type, if given.
+            limit: Only fetch the `limit` groups.
+            direction: Whether to fetch the highest count first (`"b"`) or
                 the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
-
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-            match. Each row is a dict with `type`, `key` and `count` fields.
+            List of groups of annotations that match. Each row is a dict with
+            `type`, `key` and `count` fields.
         """
 
         where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.get_event(edit_id, allow_none=True)
 
-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+    async def has_user_annotated_event(
+        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+    ) -> bool:
         """Check if a user has already annotated an event with the same key
         (e.g. already liked an event).
 
         Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
+            parent_id: The event being annotated
+            event_type: The event type of the annotation
+            aggregation_key: The aggregation key of the annotation
+            sender: The sender of the annotation
 
         Returns:
-            Deferred[bool]
+            True if the event is already annotated.
         """
 
         sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             return bool(txn.fetchone())
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )