diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2020-08-26 07:19:32 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-26 07:19:32 -0400 |
commit | 4c6c56dc58aba7af92f531655c2355d8f25e529c (patch) | |
tree | f08e4193836fc542f1400fbe407c6323f54171c2 /synapse | |
parent | Fix rate limiting unit tests. (#8167) (diff) | |
download | synapse-4c6c56dc58aba7af92f531655c2355d8f25e529c.tar.xz |
Convert simple_select_one and simple_select_one_onecol to async (#8162)
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/database.py | 36 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 14 | ||||
-rw-r--r-- | synapse/storage/databases/main/directory.py | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/e2e_room_keys.py | 8 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/group_server.py | 18 | ||||
-rw-r--r-- | synapse/storage/databases/main/media_repository.py | 13 | ||||
-rw-r--r-- | synapse/storage/databases/main/monthly_active_users.py | 15 | ||||
-rw-r--r-- | synapse/storage/databases/main/profile.py | 17 | ||||
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 6 | ||||
-rw-r--r-- | synapse/storage/databases/main/registration.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/rejections.py | 5 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/state.py | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/stats.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/user_directory.py | 9 |
16 files changed, 116 insertions, 73 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bc327e344e..181c3ec249 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,9 +29,11 @@ from typing import ( Tuple, TypeVar, Union, + overload, ) from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi from twisted.internet import defer @@ -1020,14 +1022,36 @@ class DatabasePool(object): return txn.execute_batch(sql, args) - def simple_select_one( + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( self, table: str, keyvalues: Dict[str, Any], retcols: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> defer.Deferred: + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1038,18 +1062,18 @@ class DatabasePool(object): allow_none: If true, return None instead of failing if the SELECT statement returns no rows """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + async def simple_select_one_onecol( self, table: str, keyvalues: Dict[str, Any], retcol: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> defer.Deferred: + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1061,7 +1085,7 @@ class DatabasePool(object): statement returns no rows desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 03b45dbc4d..a811a39eb5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id: str, device_id: str): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id: str): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 037e02603c..301d5d845a 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db_pool.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 2eeb9f97dc..46c3e33cc6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e1241a724b..d59d73938a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a488e0924b..c39864f59f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = "" class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db_pool.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore): ) return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 80fc1cd009..4ae255ebd8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db_pool.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e71cdd2cb4..fe30552c08 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + Arguments: + user_id: user to add/update + + Return: + Timestamp since last seen, None if never seen """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8261357d4..b8233c4848 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - async def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", @@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - def get_profile_displayname(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", desc="get_profile_displayname", ) - def get_profile_avatar_url(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_from_remote_profile_cache(self, user_id): - return self.db_pool.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6821476ee0..cea5ac9a68 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db_pool.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 321a51cc6a..eced53d470 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index cf9ba51205..1e361aaa9a 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index b3772be2b2..97ecdb16e4 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore): return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db_pool.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 991233a9bc..458f169617 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 802c9019b9..9fe97af56a 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore): return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af21fe457a..20cbcd851c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,6 +15,7 @@ import logging import re +from typing import Any, Dict, Optional from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - def get_user_in_directory(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - def get_user_directory_stream_pos(self): - return self.db_pool.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", |