summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-09-01 08:39:04 -0400
committerGitHub <noreply@github.com>2020-09-01 08:39:04 -0400
commitda77520cd1c414c9341da287967feb1bab14cbec (patch)
treeada9ea71a1271598a8bc2e9ab7c39a79ca928dac /synapse/storage/databases/main/room.py
parentMake MultiWriterIDGenerator work for streams that use negative stream IDs (#8... (diff)
downloadsynapse-da77520cd1c414c9341da287967feb1bab14cbec.tar.xz
Convert additional databases to async/await part 2 (#8200)
Diffstat (limited to 'synapse/storage/databases/main/room.py')
-rw-r--r--synapse/storage/databases/main/room.py49
1 files changed, 28 insertions, 21 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a92641c339..717df97301 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    def get_room_with_stats(self, room_id: str):
+    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve room with statistics.
 
         Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
             res["public"] = bool(res["public"])
             return res
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_room_with_stats", get_room_with_stats_txn, room_id
         )
 
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
             desc="get_public_room_ids",
         )
 
-    def count_public_rooms(self, network_tuple, ignore_non_federatable):
+    async def count_public_rooms(
+        self,
+        network_tuple: Optional[ThirdPartyInstanceID],
+        ignore_non_federatable: bool,
+    ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
 
         Args:
-            network_tuple (ThirdPartyInstanceID|None)
-            ignore_non_federatable (bool): If true filters out non-federatable rooms
+            network_tuple
+            ignore_non_federatable: If true filters out non-federatable rooms
         """
 
         def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
             txn.execute(sql, query_args)
             return txn.fetchone()[0]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
         )
 
@@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
 
         return row
 
-    def get_media_mxcs_in_room(self, room_id):
+    async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
         Args:
-            room_id (str)
+            room_id
 
         Returns:
-            The local and remote media as a lists of tuples where the key is
-            the hostname and the value is the media ID.
+            The local and remote media as a lists of the media IDs.
         """
 
         def _get_media_mxcs_in_room_txn(txn):
@@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
 
             return local_media_mxcs, remote_media_mxcs
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_media_ids_in_room", _get_media_mxcs_in_room_txn
         )
 
-    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+    async def quarantine_media_ids_in_room(
+        self, room_id: str, quarantined_by: str
+    ) -> int:
         """For a room loops through all events with media and quarantines
         the associated media
         """
@@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
@@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
 
         return local_media_mxcs, remote_media_mxcs
 
-    def quarantine_media_by_id(
+    async def quarantine_media_by_id(
         self, server_name: str, media_id: str, quarantined_by: str,
-    ):
+    ) -> int:
         """quarantines a single local or remote media id
 
         Args:
@@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
                 txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_id_txn
         )
 
-    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+    async def quarantine_media_ids_by_user(
+        self, user_id: str, quarantined_by: str
+    ) -> int:
         """quarantines all local media associated with a single user
 
         Args:
@@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
@@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_room_count(self):
-        """Retrieve a list of all rooms
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
         """
 
         def f(txn):
@@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.db_pool.runInteraction("get_rooms", f)
+        return await self.db_pool.runInteraction("get_rooms", f)
 
     async def add_event_report(
         self,