diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eddb32b4d3..484875f989 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,9 +15,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, List, Optional, Set, Tuple
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return dict(txn)
- @defer.inlineCallbacks
- def get_max_depth_of(self, event_ids):
+ async def get_max_depth_of(self, event_ids: List[str]) -> int:
"""Returns the max depth of a set of event IDs
Args:
- event_ids (list[str])
-
- Returns
- Deferred[int]
+ event_ids: The event IDs to calculate the max depth of.
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return event_results
- @defer.inlineCallbacks
- def get_missing_events(self, room_id, earliest_events, latest_events, limit):
- ids = yield self.db_pool.runInteraction(
+ async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+ ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = yield self.get_events_as_list(ids)
+ events = await self.get_events_as_list(ids)
return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.reverse()
return event_results
- @defer.inlineCallbacks
- def get_successor_events(self, event_ids):
+ async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
"""Fetch all events that have the given events as a prev event
Args:
- event_ids (iterable[str])
-
- Returns:
- Deferred[list[str]]
+ event_ids: The events to use as the previous events.
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- @defer.inlineCallbacks
- def _background_delete_non_state_event_auth(self, progress, batch_size):
+ async def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
@@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore):
return min_stream_id >= target_min_stream_id
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_AUTH_STATE_ONLY
)
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index baa7a5092a..686052bd83 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -15,8 +15,6 @@
import typing
from collections import Counter
-from twisted.internet import defer
-
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
@@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res])
- @defer.inlineCallbacks
- def count_daily_messages(self):
+ async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
- return ret
+ return await self.db_pool.runInteraction("count_messages", _count_messages)
- @defer.inlineCallbacks
- def count_daily_sent_messages(self):
+ async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then thats your own fault.
@@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)
- return ret
- @defer.inlineCallbacks
- def count_daily_active_rooms(self):
+ async def count_daily_active_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
@@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
- return ret
+ return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index f618629e09..402ae25571 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,8 @@
import logging
import re
-from typing import Optional
+from typing import Dict, List, Optional
-from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
@@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
@@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_id",
)
- @defer.inlineCallbacks
- def is_trial_user(self, user_id):
+ async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
- user_id (str)
-
- Returns:
- Deferred[bool]
+ user_id: The user to check for trial status.
"""
- info = yield self.get_user_by_id(user_id)
+ info = await self.get_user_by_id(user_id)
if not info:
return False
@@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
"get_user_by_access_token", self._query_for_auth, token
)
- @cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user_id):
+ @cached()
+    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user_id (str): The ID of the user.
+ user_id: The ID of the user.
Returns:
- defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ None, if the account has no expiration timestamp, otherwise int
+ representation of the timestamp (as a number of milliseconds since epoch).
"""
- res = yield self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
- return res
- @defer.inlineCallbacks
- def set_account_validity_for_user(
- self, user_id, expiration_ts, email_sent, renewal_token=None
- ):
+ async def set_account_validity_for_user(
+ self,
+ user_id: str,
+ expiration_ts: int,
+ email_sent: bool,
+ renewal_token: Optional[str] = None,
+ ) -> None:
"""Updates the account validity properties of the given account, with the
given values.
Args:
- user_id (str): ID of the account to update properties for.
- expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ user_id: ID of the account to update properties for.
+ expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch.
- email_sent (bool): True means a renewal email has been sent for this
- account and there's no need to send another one for the current validity
+ email_sent: True means a renewal email has been sent for this account
+ and there's no need to send another one for the current validity
period.
- renewal_token (str): Renewal token the user can use to extend the validity
+ renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
@@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
- @defer.inlineCallbacks
- def set_renewal_token_for_user(self, user_id, renewal_token):
+ async def set_renewal_token_for_user(
+ self, user_id: str, renewal_token: str
+ ) -> None:
"""Defines a renewal token for a given user.
Args:
- user_id (str): ID of the user to set the renewal token for.
- renewal_token (str): Random unique string that will be used to renew the
+ user_id: ID of the user to set the renewal token for.
+ renewal_token: Random unique string that will be used to renew the
user's account.
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user",
)
- @defer.inlineCallbacks
- def get_user_from_renewal_token(self, renewal_token):
+ async def get_user_from_renewal_token(self, renewal_token: str) -> str:
"""Get a user ID from a renewal token.
Args:
- renewal_token (str): The renewal token to perform the lookup with.
+ renewal_token: The renewal token to perform the lookup with.
Returns:
- defer.Deferred[str]: The ID of the user to which the token belongs.
+ The ID of the user to which the token belongs.
"""
- res = yield self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
- return res
-
- @defer.inlineCallbacks
- def get_renewal_token_for_user(self, user_id):
+ async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
Args:
- user_id (str): The user ID to lookup a token for.
+ user_id: The user ID to lookup a token for.
Returns:
- defer.Deferred[str]: The renewal token associated with this user ID.
+ The renewal token associated with this user ID.
"""
- res = yield self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
- return res
-
- @defer.inlineCallbacks
- def get_users_expiring_soon(self):
+ async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ A list of dictionaries mapping user ID to expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, values)
return self.db_pool.cursor_to_dict(txn)
- res = yield self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self.config.account_validity.renew_at,
)
- return res
-
- @defer.inlineCallbacks
- def set_renewal_mail_status(self, user_id, email_sent):
+ async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
"""Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet).
Args:
- user_id (str): ID of the user to set/unset the flag for.
- email_sent (bool): Flag which indicates whether a renewal email has been sent
+ user_id: ID of the user to set/unset the flag for.
+ email_sent: Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self.db_pool.simple_update_one(
+ await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status",
)
- @defer.inlineCallbacks
- def delete_account_validity_for_user(self, user_id):
+ async def delete_account_validity_for_user(self, user_id: str) -> None:
"""Deletes the entry for the given user in the account validity table, removing
their expiration date and renewal token.
Args:
- user_id (str): ID of the user to remove from the account validity table.
+ user_id: ID of the user to remove from the account validity table.
"""
- yield self.db_pool.simple_delete_one(
+ await self.db_pool.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
)
- async def is_server_admin(self, user):
+ async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
+ user: user ID of the user to test
- Returns (bool):
+ Returns:
true iff the user is a server admin, false otherwise.
"""
res = await self.db_pool.simple_select_one_onecol(
@@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
return None
- @cachedInlineCallbacks()
- def is_real_user(self, user_id):
+ @cached()
+ async def is_real_user(self, user_id: str) -> bool:
"""Determines if the user is a real user, ie does not have a 'user_type'.
Args:
- user_id (str): user id to test
+ user_id: user id to test
Returns:
- Deferred[bool]: True if user 'user_type' is null or empty string
+ True if user 'user_type' is null or empty string
"""
- res = yield self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
- return res
@cached()
- def is_support_user(self, user_id):
+ async def is_support_user(self, user_id: str) -> bool:
"""Determines if the user is of type UserTypes.SUPPORT
Args:
- user_id (str): user id to test
+ user_id: user id to test
Returns:
- Deferred[bool]: True if user is of type UserTypes.SUPPORT
+ True if user is of type UserTypes.SUPPORT
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
@@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_external_id",
)
- @defer.inlineCallbacks
- def count_all_users(self):
+ async def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
@@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"]
return 0
- ret = yield self.db_pool.runInteraction("count_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
"""
@@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"count_daily_user_type", _count_daily_user_type
)
- @defer.inlineCallbacks
- def count_nonbridged_users(self):
+ async def count_nonbridged_users(self):
def _count_users(txn):
txn.execute(
"""
@@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db_pool.runInteraction("count_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_users", _count_users)
- @defer.inlineCallbacks
- def count_real_users(self):
+ async def count_real_users(self):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn):
@@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"]
return 0
- ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_real_users", _count_users)
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user
@@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret["user_id"]
return None
- @defer.inlineCallbacks
- def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self.db_pool.simple_upsert(
+ async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- @defer.inlineCallbacks
- def user_get_threepids(self, user_id):
- ret = yield self.db_pool.simple_select_list(
+ async def user_get_threepids(self, user_id):
+ return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
- return ret
def user_delete_threepid(self, user_id, medium, address):
return self.db_pool.simple_delete(
@@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_id_servers_user_bound",
)
- @cachedInlineCallbacks()
- def get_user_deactivated_status(self, user_id):
+ @cached()
+ async def get_user_deactivated_status(self, user_id: str) -> bool:
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
- user_id (str): The ID of the user to retrieve the status for.
+ user_id: The ID of the user to retrieve the status for.
Returns:
- defer.Deferred(bool): The requested value.
+ True if the user was deactivated, false if the user is still active.
"""
- res = yield self.db_pool.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- @defer.inlineCallbacks
- def _background_update_set_deactivated_flag(self, progress, batch_size):
+ async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
@@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else:
return False, len(rows)
- end, nb_processed = yield self.db_pool.runInteraction(
+ end, nb_processed = await self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"users_set_deactivated_flag"
)
return nb_processed
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
+ async def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
@@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
- yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
+ await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1
@@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
- @defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
+ async def add_access_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ ) -> None:
"""Adds an access token for the given user.
Args:
- user_id (str): The user ID.
- token (str): The new access token to add.
- device_id (str): ID of the device to associate with the access
- token
- valid_until_ms (int|None): when the token is valid until. None for
- no expiry.
+ user_id: The user ID.
+ token: The new access token to add.
+ device_id: ID of the device to associate with the access token
+ valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self.db_pool.simple_insert(
+ await self.db_pool.simple_insert(
"access_tokens",
{
"id": next_id,
@@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- txn.call_after(self.is_guest.invalidate, (user_id,))
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
@@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return self.db_pool.runInteraction("delete_access_token", f)
- @cachedInlineCallbacks()
- def is_guest(self, user_id):
- res = yield self.db_pool.simple_select_one_onecol(
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self.clock.time_msec(),
)
- @defer.inlineCallbacks
- def set_user_deactivated_status(self, user_id, deactivated):
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
- user_id (str): The ID of the user to set the status for.
- deactivated (bool): The value to set for `deactivated`.
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
"""
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
+ txn.call_after(self.is_guest.invalidate, (user_id,))
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
+ async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
@@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, user["name"], use_delta=True
)
- yield self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index eedd2d96c3..e4e0a0c433 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,14 +15,13 @@
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import Dict, List, Tuple
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -30,30 +29,26 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
- def get_tags_for_user(self, user_id):
+ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for a user.
Args:
- user_id(str): The user to get the tags for.
+ user_id: The user to get the tags for.
Returns:
- A deferred dict mapping from room_id strings to dicts mapping from
- tag strings to tag content.
+ A mapping from room_id strings to dicts mapping from tag strings to
+ tag content.
"""
- deferred = self.db_pool.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- @deferred.addCallback
- def tags_by_room(rows):
- tags_by_room = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
- return tags_by_room
-
- return deferred
+ tags_by_room = {}
+ for row in rows:
+ room_tags = tags_by_room.setdefault(row["room_id"], {})
+ room_tags[row["tag"]] = db_to_json(row["content"])
+ return tags_by_room
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
return results, upto_token, limited
- @defer.inlineCallbacks
- def get_updated_tags(self, user_id, stream_id):
+ async def get_updated_tags(
+ self, user_id: str, stream_id: int
+    ) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
+
Returns:
- A deferred dict mapping from room_id strings to lists of tag
- strings for all the rooms that changed since the stream_id token.
+            A mapping from room_id strings to dicts mapping from tag strings to
+            tag content, for all the rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
@@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
return {}
- room_ids = yield self.db_pool.runInteraction(
+ room_ids = await self.db_pool.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
- tags_by_room = yield self.get_tags_for_user(user_id)
+ tags_by_room = await self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
return results
- def get_tags_for_room(self, user_id, room_id):
+ async def get_tags_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the tags for the given room
+
Args:
- user_id(str): The user to get tags for
- room_id(str): The room to get tags for
+ user_id: The user to get tags for
+ room_id: The room to get tags for
+
Returns:
- A deferred list of string tags.
+ A mapping of tags to tag content.
"""
- return self.db_pool.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
- ).addCallback(
- lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
+ return {row["tag"]: db_to_json(row["content"]) for row in rows}
class TagsStore(TagsWorkerStore):
- @defer.inlineCallbacks
- def add_tag_to_room(self, user_id, room_id, tag, content):
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
"""Add a tag to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- room_id(str): The room to add a tag for.
- tag(str): The tag name to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
Returns:
- A deferred that completes once the tag has been added.
+ The next account data ID.
"""
content_json = json.dumps(content)
@@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
+ await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def remove_tag_from_room(self, user_id, room_id, tag):
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
+
Returns:
- A deferred that completes once the tag has been removed
+ The next account data ID.
"""
def remove_tag_txn(txn, next_id):
@@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
+ await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- def _update_revision_txn(self, txn, user_id, room_id, next_id):
+ def _update_revision_txn(
+ self, txn, user_id: str, room_id: str, next_id: int
+ ) -> None:
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
- user_id(str): The ID of the user.
- room_id(str): The ID of the room.
- next_id(int): The the revision to advance to.
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+            next_id: The revision to advance to.
"""
txn.call_after(
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1e2d584098..139085b672 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
@@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
- rows = yield self.db_pool.execute(
+ rows = await self.db_pool.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return False, batch_size
- finished, result = yield self.db_pool.runInteraction(
+ finished, result = await self.db_pool.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
return result * BATCH_SIZE_SCALE_FACTOR
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
+ async def _background_index_state(self, progress, batch_size):
def reindex_txn(conn):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
@@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- yield self.db_pool.runWithConnection(reindex_txn)
+ await self.db_pool.runWithConnection(reindex_txn)
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.STATE_GROUP_INDEX_UPDATE_NAME
)
|