summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-11 17:21:20 -0400
committerGitHub <noreply@github.com>2020-08-11 17:21:20 -0400
commit04faa0bfa960d9f0dc60e9cf4ec270221249b7ca (patch)
tree82624acd3b8f965337b423e5901a31d2be19cbb8 /synapse/storage/databases
parentConverts event_federation and registration databases to async/await (#8061) (diff)
downloadsynapse-04faa0bfa960d9f0dc60e9cf4ec270221249b7ca.tar.xz
Convert tags and metrics databases to async/await (#8062)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/metrics.py20
-rw-r--r--synapse/storage/databases/main/tags.py103
2 files changed, 59 insertions, 64 deletions
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index baa7a5092a..686052bd83 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -15,8 +15,6 @@
 import typing
 from collections import Counter
 
-from twisted.internet import defer
-
 from synapse.metrics import BucketCollector
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
@@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
         self._current_forward_extremities_amount = Counter([x[0] for x in res])
 
-    @defer.inlineCallbacks
-    def count_daily_messages(self):
+    async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
-        return ret
+        return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    @defer.inlineCallbacks
-    def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self):
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then thats your own fault.
@@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_sent_messages", _count_messages
         )
-        return ret
 
-    @defer.inlineCallbacks
-    def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self):
         def _count(txn):
             sql = """
                 SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
@@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
-        return ret
+        return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index eedd2d96c3..e4e0a0c433 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,14 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.storage._base import db_to_json
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -30,30 +29,26 @@ logger = logging.getLogger(__name__)
 
 class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
-    def get_tags_for_user(self, user_id):
+    async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
         """Get all the tags for a user.
 
 
         Args:
-            user_id(str): The user to get the tags for.
+            user_id: The user to get the tags for.
         Returns:
-            A deferred dict mapping from room_id strings to dicts mapping from
-            tag strings to tag content.
+            A mapping from room_id strings to dicts mapping from tag strings to
+            tag content.
         """
 
-        deferred = self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        @deferred.addCallback
-        def tags_by_room(rows):
-            tags_by_room = {}
-            for row in rows:
-                room_tags = tags_by_room.setdefault(row["room_id"], {})
-                room_tags[row["tag"]] = db_to_json(row["content"])
-            return tags_by_room
-
-        return deferred
+        tags_by_room = {}
+        for row in rows:
+            room_tags = tags_by_room.setdefault(row["room_id"], {})
+            room_tags[row["tag"]] = db_to_json(row["content"])
+        return tags_by_room
 
     async def get_all_updated_tags(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
         return results, upto_token, limited
 
-    @defer.inlineCallbacks
-    def get_updated_tags(self, user_id, stream_id):
+    async def get_updated_tags(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, List[str]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
         Args:
             user_id(str): The user to get the tags for.
             stream_id(int): The earliest update to get for the user.
+
         Returns:
-            A deferred dict mapping from room_id strings to lists of tag
-            strings for all the rooms that changed since the stream_id token.
+            A mapping from room_id strings to lists of tag strings for all the
+            rooms that changed since the stream_id token.
         """
 
         def get_updated_tags_txn(txn):
@@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore):
         if not changed:
             return {}
 
-        room_ids = yield self.db_pool.runInteraction(
+        room_ids = await self.db_pool.runInteraction(
             "get_updated_tags", get_updated_tags_txn
         )
 
         results = {}
         if room_ids:
-            tags_by_room = yield self.get_tags_for_user(user_id)
+            tags_by_room = await self.get_tags_for_user(user_id)
             for room_id in room_ids:
                 results[room_id] = tags_by_room.get(room_id, {})
 
         return results
 
-    def get_tags_for_room(self, user_id, room_id):
+    async def get_tags_for_room(
+        self, user_id: str, room_id: str
+    ) -> Dict[str, JsonDict]:
         """Get all the tags for the given room
+
         Args:
-            user_id(str): The user to get tags for
-            room_id(str): The room to get tags for
+            user_id: The user to get tags for
+            room_id: The room to get tags for
+
         Returns:
-            A deferred list of string tags.
+            A mapping of tags to tag content.
         """
-        return self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             table="room_tags",
             keyvalues={"user_id": user_id, "room_id": room_id},
             retcols=("tag", "content"),
             desc="get_tags_for_room",
-        ).addCallback(
-            lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
         )
+        return {row["tag"]: db_to_json(row["content"]) for row in rows}
 
 
 class TagsStore(TagsWorkerStore):
-    @defer.inlineCallbacks
-    def add_tag_to_room(self, user_id, room_id, tag, content):
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
         """Add a tag to a room for a user.
+
         Args:
-            user_id(str): The user to add a tag for.
-            room_id(str): The room to add a tag for.
-            tag(str): The tag name to add.
-            content(dict): A json object to associate with the tag.
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
         Returns:
-            A deferred that completes once the tag has been added.
+            The next account data ID.
         """
         content_json = json.dumps(content)
 
@@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
+            await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def remove_tag_from_room(self, user_id, room_id, tag):
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
         """Remove a tag from a room for a user.
+
         Returns:
-            A deferred that completes once the tag has been removed
+            The next account data ID.
         """
 
         def remove_tag_txn(txn, next_id):
@@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
+            await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()
 
-    def _update_revision_txn(self, txn, user_id, room_id, next_id):
+    def _update_revision_txn(
+        self, txn, user_id: str, room_id: str, next_id: int
+    ) -> None:
         """Update the latest revision of the tags for the given user and room.
 
         Args:
             txn: The database cursor
-            user_id(str): The ID of the user.
-            room_id(str): The ID of the room.
-            next_id(int): The the revision to advance to.
+            user_id: The ID of the user.
+            room_id: The ID of the room.
+            next_id: The the revision to advance to.
         """
 
         txn.call_after(