Diffstat (limited to 'synapse/storage/databases')
 synapse/storage/databases/main/account_data.py     | 59
 synapse/storage/databases/main/end_to_end_keys.py  | 54
 synapse/storage/databases/main/media_repository.py | 31
 synapse/storage/databases/main/search.py           | 15
 synapse/storage/databases/main/user_directory.py   | 44
 5 files changed, 118 insertions(+), 85 deletions(-)
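
Every hunk below follows the same conversion: a storage method that returned the Deferred from db_pool.runInteraction becomes a coroutine that awaits it, with types moved from the docstring into the signature. A minimal before/after sketch of the pattern (ExampleStore, count_things and the things table are illustrative, not part of this diff):

from typing import Dict


class ExampleStore:
    """Illustrative only; assumes a db_pool whose runInteraction is awaitable."""

    def __init__(self, db_pool):
        self.db_pool = db_pool

    # Before: return the Deferred from runInteraction; callers add callbacks or yield it.
    def count_things_old(self, user_id):
        def _count_txn(txn):
            txn.execute("SELECT COUNT(*) FROM things WHERE user_id = ?", (user_id,))
            return {"things": txn.fetchone()[0]}

        return self.db_pool.runInteraction("count_things", _count_txn)

    # After: a typed coroutine that awaits the interaction; callers use `await`.
    async def count_things(self, user_id: str) -> Dict[str, int]:
        def _count_txn(txn):
            txn.execute("SELECT COUNT(*) FROM things WHERE user_id = ?", (user_id,))
            return {"things": txn.fetchone()[0]}

        return await self.db_pool.runInteraction("count_things", _count_txn)
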
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 04042a2c98..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
 
 import abc
 import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cached()
-    def get_account_data_for_user(self, user_id):
+    async def get_account_data_for_user(
+        self, user_id: str
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a user.
 
         Args:
-            user_id(str): The user to get the account_data for.
+            user_id: The user to get the account_data for.
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A 2-tuple of a dict of global account_data and a dict mapping from
+            room_id string to per room account_data dicts.
         """
 
         def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return global_account_data, by_room
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
             return None
 
     @cached(num_args=2)
-    def get_account_data_for_room(self, user_id, room_id):
+    async def get_account_data_for_room(
+        self, user_id: str, room_id: str
+    ) -> Dict[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
         Returns:
-            A deferred dict of the room account_data
+            A dict of the room account_data
         """
 
         def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
     @cached(num_args=3, max_entries=5000)
-    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+    async def get_account_data_for_room_and_type(
+        self, user_id: str, room_id: str, account_data_type: str
+    ) -> Optional[JsonDict]:
         """Get the client account_data of given type for a user for a room.
 
         Args:
-            user_id(str): The user to get the account_data for.
-            room_id(str): The room to get the account_data for.
-            account_data_type (str): The account data type to get.
+            user_id: The user to get the account_data for.
+            room_id: The room to get the account_data for.
+            account_data_type: The account data type to get.
         Returns:
-            A deferred of the room account_data for that type, or None if
-            there isn't any set.
+            The room account_data for that type, or None if there isn't any set.
         """
 
         def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return db_to_json(content_json) if content_json else None
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    def get_updated_account_data_for_user(self, user_id, stream_id):
+    async def get_updated_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
         """Get all the client account_data for a that's changed for a user
 
         Args:
-            user_id(str): The user to get the account_data for.
-            stream_id(int): The point in the stream since which to get updates
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
         Returns:
             A 2-tuple of a dict of global account_data and a dict
             mapping from room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return defer.succeed(({}, {}))
+            return ({}, {})
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
 
         return self._account_data_id_gen.get_current_token()
 
-    def _update_max_stream_id(self, next_id: int):
+    async def _update_max_stream_id(self, next_id: int) -> None:
         """Update the max stream_id
 
         Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
 
-        return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
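
With account_data.py converted, callers await these methods and receive plain values rather than Deferreds; the @cached decorators stay in place. A hedged usage sketch (the wrapper function and the store variable are assumed; only the three store methods and their signatures come from this file):

async def get_account_data_snapshot(store, user_id: str, room_id: str):
    # Global account data plus per-room account data, as a 2-tuple.
    global_data, by_room = await store.get_account_data_for_user(user_id)

    # All account data the user has set for one room.
    room_data = await store.get_account_data_for_room(user_id, room_id)

    # A single typed entry for that room, or None if it was never set.
    fully_read = await store.get_account_data_for_room_and_type(
        user_id, room_id, "m.fully_read"
    )

    return global_data, by_room, room_data, fully_read
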
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1ee062e3c4..5a7de44b33 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -34,13 +34,15 @@ if TYPE_CHECKING:
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
-    def get_e2e_device_keys_for_federation_query(self, user_id: str):
+    async def get_e2e_device_keys_for_federation_query(
+        self, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         """Get all devices (with any device keys) for a user
 
         Returns:
-            Deferred which resolves to (stream_id, devices)
+            (stream_id, devices)
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_e2e_device_keys_for_federation_query",
             self._get_e2e_device_keys_for_federation_query_txn,
             user_id,
@@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def count_e2e_one_time_keys(self, user_id, device_id):
+    async def count_e2e_one_time_keys(
+        self, user_id: str, device_id: str
+    ) -> Dict[str, int]:
         """ Count the number of one time keys the server has for a device
         Returns:
-            Dict mapping from algorithm to number of keys for that algorithm.
+            A mapping from algorithm to number of keys for that algorithm.
         """
 
         def _count_e2e_one_time_keys(txn):
@@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result[algorithm] = key_count
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
@@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         list_name="user_ids",
         num_args=1,
     )
-    def _get_bare_e2e_cross_signing_keys_bulk(
+    async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
@@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         the signatures for the calling user need to be fetched.
 
         Args:
-            user_ids (list[str]): the users whose keys are being requested
+            user_ids: the users whose keys are being requested
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  If a user's cross-signing keys were not found, either
-                their user ID will not be in the dict, or their user ID will map
-                to None.
+            A mapping from user ID to key type to key data. If a user's cross-signing
+            keys were not found, either their user ID will not be in the dict, or
+            their user ID will map to None.
 
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
@@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
         """
@@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "set_e2e_device_keys", _set_e2e_device_keys_txn
         )
 
-    def claim_e2e_one_time_keys(self, query_list):
-        """Take a list of one time keys out of the database"""
+    async def claim_e2e_one_time_keys(
+        self, query_list: Iterable[Tuple[str, str, str]]
+    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+        """Take a list of one time keys out of the database.
+
+        Args:
+            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of user ID -> a map of device ID -> a map of key ID -> JSON bytes.
+        """
 
         @trace
         def _claim_e2e_one_time_keys(txn):
@@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 )
             return result
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    def delete_e2e_keys_by_device(self, user_id, device_id):
+    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
                 {
@@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
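
claim_e2e_one_time_keys now spells out its argument and return shapes. A sketch of a caller consuming the nested result (the wrapper function and the algorithm choice are assumptions, not from this diff):

import json


async def claim_one_time_keys(store, user_id: str, device_ids):
    # One (user ID, device ID, algorithm) tuple per device to claim from.
    query = [(user_id, device_id, "signed_curve25519") for device_id in device_ids]

    claimed = await store.claim_e2e_one_time_keys(query)

    # Result shape: user ID -> device ID -> key ID -> JSON bytes.
    for claimed_user, devices in claimed.items():
        for device_id, keys in devices.items():
            for key_id, key_json in keys.items():
                print(claimed_user, device_id, key_id, json.loads(key_json))
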
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3919ecad69..86557d5512 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    def get_url_cache(self, url, ts):
+    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 )
             )
 
-        return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+        return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
     async def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
@@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+    async def update_cached_last_access_time(
+        self,
+        local_media: Iterable[str],
+        remote_media: Iterable[Tuple[str, str]],
+        time_ms: int,
+    ) -> None:
         """Updates the last access time of the given media
 
         Args:
-            local_media (iterable[str]): Set of media_ids
-            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            local_media: Set of media_ids
+            remote_media: Set of (server_name, media_id)
             time_ms: Current time in milliseconds
         """
 
@@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
         )
 
@@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
         )
 
-    def delete_remote_media(self, media_origin, media_id):
+    async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
         def delete_remote_media_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_remote_media", delete_remote_media_txn
         )
 
-    def get_expired_url_cache(self, now_ts):
+    async def get_expired_url_cache(self, now_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository_url_cache"
             " WHERE expires_ts < ?"
@@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
@@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "delete_url_cache", _delete_url_cache_txn
         )
 
-    def get_url_cache_media_before(self, before_ts):
+    async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
         sql = (
             "SELECT media_id FROM local_media_repository"
             " WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
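
update_cached_last_access_time takes separate iterables for local media IDs and (server_name, media_id) pairs, and the expiry helpers now return lists directly. A usage sketch (the wrapper function and the clock are assumptions; the store methods and their argument shapes come from this file):

import time


async def touch_and_expire_media(store, local_media_ids, remote_media):
    # local_media_ids: iterable of media_id strings for this homeserver.
    # remote_media: iterable of (server_name, media_id) tuples.
    now_ms = int(time.time() * 1000)
    await store.update_cached_last_access_time(local_media_ids, remote_media, now_ms)

    # URL-cache entries whose expires_ts has passed, ready to be purged.
    return await store.get_expired_url_cache(now_ms)
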
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7f8d1880e5..f01cf2fd02 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
 import logging
 import re
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
 
 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
             "count": count,
         }
 
-    def _find_highlights_in_postgres(self, search_query, events):
+    async def _find_highlights_in_postgres(
+        self, search_query: str, events: List[EventBase]
+    ) -> Set[str]:
         """Given a list of events and a search term, return a list of words
         that match from the content of the event.
 
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         highlight the matching parts.
 
         Args:
-            search_query (str)
-            events (list): A list of events
+            search_query
+            events: A list of events
 
         Returns:
-            deferred : A set of strings.
+            A set of strings.
         """
 
         def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
             return highlight_words
 
-        return self.db_pool.runInteraction("_find_highlights", f)
+        return await self.db_pool.runInteraction("_find_highlights", f)
 
 
 def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f2e93614..1e96ae7828 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
 
 import logging
 import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -365,7 +365,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return False
 
-    def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+    async def update_profile_in_user_dir(
+        self, user_id: str, display_name: str, avatar_url: str
+    ) -> None:
         """
         Update or add a user's profile in the user directory.
         """
@@ -458,17 +460,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
-    def add_users_who_share_private_room(self, room_id, user_id_tuples):
+    async def add_users_who_share_private_room(
+        self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
-            room_id (str)
-            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+            room_id
+            user_id_tuples: iterable of 2-tuple of user IDs.
         """
 
         def _add_users_who_share_room_txn(txn):
@@ -484,17 +488,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 value_values=None,
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
-    def add_users_in_public_rooms(self, room_id, user_ids):
+    async def add_users_in_public_rooms(
+        self, room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
-            room_id (str)
-            user_ids (list[str])
+            room_id
+            user_ids
         """
 
         def _add_users_in_public_rooms_txn(txn):
@@ -508,11 +514,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 value_values=None,
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )
 
-    def delete_all_from_user_dir(self):
+    async def delete_all_from_user_dir(self) -> None:
         """Delete the entire user directory
         """
 
@@ -523,7 +529,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
@@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
-    def remove_from_user_dir(self, user_id):
+    async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -578,7 +584,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_from_user_dir", _remove_from_user_dir_txn
         )
 
@@ -605,14 +611,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
         return user_ids
 
-    def remove_user_who_share_room(self, user_id, room_id):
+    async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
         """
         Deletes entries in the users_who_share_*_rooms table. The first
         user should be a local user.
 
         Args:
-            user_id (str)
-            room_id (str)
+            user_id
+            room_id
         """
 
         def _remove_user_who_share_room_txn(txn):
@@ -632,7 +638,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
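
The user-directory writers are awaited purely for their side effects (they all return None). A sketch of a caller choosing between the public and private tables (the wrapper function and the is_public flag are assumptions; the two store methods and their argument shapes come from this file):

async def update_sharing_tables(
    store, room_id: str, is_public: bool, joined_user_ids, local_user_id: str
) -> None:
    if is_public:
        # Public rooms: every joined user becomes visible in the directory.
        await store.add_users_in_public_rooms(room_id, joined_user_ids)
    else:
        # Private rooms: record (local user, other user) pairs who share the room.
        pairs = [
            (local_user_id, other)
            for other in joined_user_ids
            if other != local_user_id
        ]
        await store.add_users_who_share_private_room(room_id, pairs)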