diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc
new file mode 100644
index 0000000000..e26764dea1
--- /dev/null
+++ b/changelog.d/8162.misc
@@ -0,0 +1 @@
+ Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index bc327e344e..181c3ec249 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -29,9 +29,11 @@ from typing import (
Tuple,
TypeVar,
Union,
+ overload,
)
from prometheus_client import Histogram
+from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
- def simple_select_one(
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one",
+ ) -> Dict[str, Any]:
+ ...
+
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
+ ...
+
+ async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
- ) -> defer.Deferred:
+ ) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def simple_select_one_onecol(
+ async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
- ) -> defer.Deferred:
+ ) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 03b45dbc4d..a811a39eb5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id: str, device_id: str):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
- defer.Deferred for a dict containing the device information
+ A dict containing the device information
Raises:
StoreError: if the device is not found
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id: str):
+ async def get_device_list_last_stream_id_for_remote(
+ self, user_id: str
+ ) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..301d5d845a 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
- def get_room_alias_creator(self, room_alias):
- return self.db_pool.simple_select_one_onecol(
+ async def get_room_alias_creator(self, room_alias: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..46c3e33cc6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
- def count_e2e_room_keys(self, user_id, version):
+ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e1241a724b..d59d73938a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_received_ts(self, event_id):
+ async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
- event_id (str)
+ event_id: The event ID to query.
Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
+ Timestamp in milliseconds, or None for events that were persisted
+ before received_ts was implemented.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index a488e0924b..c39864f59f 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
- def get_group(self, group_id):
- return self.db_pool.simple_select_one(
+ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
)
return bool(result)
- def is_user_admin_in_group(self, group_id, user_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_user_admin_in_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
+ async def is_user_invited_to_local_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
"""Has the group server invited a user?
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..4ae255ebd8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Optional
+
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
- def get_local_media(self, media_id):
+ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
+
Returns:
None if the media_id doesn't exist.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
- def get_cached_remote_media(self, origin, media_id):
- return self.db_pool.simple_select_one(
+ async def get_cached_remote_media(
+ self, origin, media_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..fe30552c08 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
+ async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id: user to add/update
+
+ Return:
+ Timestamp since last seen, None if never seen
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..b8233c4848 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
- async def get_profileinfo(self, user_localpart):
+ async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
- def get_profile_displayname(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_displayname(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
- def get_profile_avatar_url(self, user_localpart):
- return self.db_pool.simple_select_one_onecol(
+ async def get_profile_avatar_url(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
- def get_from_remote_profile_cache(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_from_remote_profile_cache(
+ self, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 6821476ee0..cea5ac9a68 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
- def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.db_pool.simple_select_one_onecol(
+ async def get_last_receipt_event_id_for_user(
+ self, user_id: str, room_id: str, receipt_type: str
+ ) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 321a51cc6a..eced53d470 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
import logging
import re
-from typing import Awaitable, Dict, List, Optional
+from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
- def get_user_by_id(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
- def get_user_pending_deactivation(self):
+ async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def get_rejection_reason(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index b3772be2b2..97ecdb16e4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
- def get_room(self, room_id):
+ async def get_room(self, room_id: str) -> dict:
"""Retrieve a room.
Args:
- room_id (str): The ID of the room to retrieve.
+ room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.db_pool.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self.db_pool.simple_select_one_onecol(
+ async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 991233a9bc..458f169617 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self.db_pool.simple_select_one_onecol(
+ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+ return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..9fe97af56a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
- def get_stats_positions(self):
+ async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- def get_earliest_token_for_stats(self, stats_type, id):
+ async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
- Deferred[int]
+ The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.db_pool.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..20cbcd851c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
import logging
import re
+from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- def get_user_in_directory(self, user_id):
- return self.db_pool.simple_select_one(
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- def get_user_directory_stream_pos(self):
- return self.db_pool.simple_select_one_onecol(
+ async def get_user_directory_stream_pos(self) -> int:
+ return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index b609b30d4a..60ebc95f3e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
@defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
# Setting displayname a second time is forbidden
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname("caroline", "Caroline")
+ )
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/pic.gif",
)
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index e01de158e5..834b4a0af6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.return_value = (
+ self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
- defer.succeed(1)
+ lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..04de0b9dbe 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
- self.assertTrue(self.store.get_user_by_id(user_id))
+ self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 13bcac743a..bf22540d99 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db_pool.simple_select_one_onecol(
- table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ value = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one_onecol(
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ )
)
self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"],
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ retcols=["colA", "colB", "colC"],
+ )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "Not here"},
- retcols=["colA"],
- allow_none=True,
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "Not here"},
+ retcols=["colA"],
+ allow_none=True,
+ )
)
self.assertFalse(ret)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9d5b8aa47d..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ )
self.assertEquals(
- "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ ),
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here"
+ )
)
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ ),
)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 58f827d8d3..70c55cd650 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
- (yield self.store.get_user_by_id(self.user_id)),
+ (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield self.store.get_room(self.room.to_string())),
+ (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+ self.assertIsNone(
+ (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+ )
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (yield self.store.get_room_with_stats(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
- self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats("!uknown:test")
+ )
+ ),
+ )
class RoomEventsStoreTestCase(unittest.TestCase):
|