diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 81e64de126..7a5f0bab05 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
- cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ cross_signing_key = yield defer.ensureDeferred(
+ self.get_e2e_cross_signing_key(user, "master")
+ )
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
@@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- cross_signing_key = yield self.get_e2e_cross_signing_key(
- user, "self_signing"
+ cross_signing_key = yield defer.ensureDeferred(
+ self.get_e2e_cross_signing_key(user, "self_signing")
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
@@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore):
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
- user_id/device_id to update stream_id and the relevent json-encoded
+ user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
Returns:
@@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore):
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
- function to get further updatees.
+ function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 7819bfcbb3..037e02603c 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,30 +14,29 @@
# limitations under the License.
from collections import namedtuple
-from typing import Optional
-
-from twisted.internet import defer
+from typing import Iterable, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
+from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias):
- """ Get's the room_id and server list for a given room_alias
+ async def get_association_from_room_alias(
+ self, room_alias: RoomAlias
+ ) -> Optional[RoomAliasMapping]:
+ """Gets the room_id and server list for a given room_alias
Args:
- room_alias (RoomAlias)
+ room_alias: The alias to translate to an ID.
Returns:
- Deferred: results in namedtuple with keys "room_id" and
- "servers" or None if no association can be found
+ The room alias mapping or None if no association can be found.
"""
- room_id = yield self.db_pool.simple_select_one_onecol(
+ room_id = await self.db_pool.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
- servers = yield self.db_pool.simple_select_onecol(
+ servers = await self.db_pool.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
- @defer.inlineCallbacks
- def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
+ async def create_room_alias_association(
+ self,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Iterable[str],
+ creator: Optional[str] = None,
+ ) -> None:
""" Creates an association between a room alias and room_id/servers
Args:
- room_alias (RoomAlias)
- room_id (str)
- servers (list)
- creator (str): Optional user_id of creator.
-
- Returns:
- Deferred
+ room_alias: The alias to create.
+ room_id: The target of the alias.
+ servers: A list of servers through which it may be possible to join the room
+ creator: Optional user_id of creator.
"""
def alias_txn(txn):
@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- return ret
- @defer.inlineCallbacks
- def delete_room_alias(self, room_alias):
- room_id = yield self.db_pool.runInteraction(
+ async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+ room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
- def _delete_room_alias_txn(self, txn, room_alias):
+ def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index c4aaec3993..2eeb9f97dc 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -23,8 +21,9 @@ from synapse.util import json_encoder
class EndToEndRoomKeyStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ async def update_e2e_room_key(
+ self, user_id, version, room_id, session_id, room_key
+ ):
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
@@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="update_e2e_room_key",
)
- @defer.inlineCallbacks
- def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.
Args:
@@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}
)
- yield self.db_pool.simple_insert_many(
+ await self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
- @defer.inlineCallbacks
- def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
@@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = yield self.db_pool.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- @defer.inlineCallbacks
- def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_e2e_room_keys(
+ self, user_id, version, room_id=None, session_id=None
+ ):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
@@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred of the deletion transaction
+ The deletion transaction
"""
keyvalues = {"user_id": user_id, "version": int(version)}
@@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- yield self.db_pool.simple_delete(
+ await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 6126376a6f..f93e0d320d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,12 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
-from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
- @defer.inlineCallbacks
- def get_e2e_device_keys(
+ async def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
@@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = yield self.db_pool.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
@@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv(result)
return result
- @defer.inlineCallbacks
- def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ async def get_e2e_one_time_keys(
+ self, user_id: str, device_id: str, key_ids: List[str]
+ ) -> Dict[Tuple[str, str], str]:
"""Retrieve a number of one-time keys for a user
Args:
@@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
retrieve
Returns:
- deferred resolving to Dict[(str, str), str]: map from (algorithm,
- key_id) to json string for key
+ A map from (algorithm, key_id) to json string for key
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
- @defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ async def add_e2e_one_time_keys(
+ self,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
"""Insert some new one time keys for a device. Errors if any of the
keys already exist.
Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- time_now(long): insertion time to record (ms since epoch)
- new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
- (algorithm, key_id, key json)
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn):
@@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- @defer.inlineCallbacks
- def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ async def get_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, from_user_id: Optional[str] = None
+ ) -> Optional[dict]:
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
+ user_id: the user whose key is being requested
+ key_type: the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
+ from_user_id: if specified, signatures made by this user on
the self-signing key will be included in the result
Returns:
dict of the key data or None if not found
"""
- res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
user_keys = res.get(user_id)
if not user_keys:
return None
@@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return keys
- @defer.inlineCallbacks
- def get_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str], from_user_id: str = None
- ) -> defer.Deferred:
+ async def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: Optional[str] = None
+ ) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users.
Args:
- user_ids (list[str]): the users whose keys are being requested
- from_user_id (str): if specified, signatures made by this user on
+ user_ids: the users whose keys are being requested
+ from_user_id: if specified, signatures made by this user on
the self-signing keys will be included in the result
Returns:
- Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
- key data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A map of user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict,
+ or their user ID will map to None.
"""
- result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+ result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 02b01d9619..e71cdd2cb4 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,8 +15,6 @@
import logging
from typing import List
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
@@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
- @defer.inlineCallbacks
- def upsert_monthly_active_user(self, user_id):
+ async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
Args:
- user_id (str): user to add/update
-
- Returns:
- Deferred
+ user_id: user to add/update
"""
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# _initialise_reserved_users reasoning that it would be very strange to
# include a support user in this context.
- is_support = yield self.is_support_user(user_id)
+ is_support = await self.is_support_user(user_id)
if is_support:
return
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
@@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
return is_insert
- @defer.inlineCallbacks
- def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = yield self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id)
if is_guest:
return
- is_trial = yield self.is_trial_user(user_id)
+ is_trial = await self.is_trial_user(user_id)
if is_trial:
return
- last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy
@@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
if not self._limit_usage_by_mau:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
else:
- count = yield self.get_monthly_active_count()
+ count = await self.get_monthly_active_count()
if count < self._max_mau_value:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
|