From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/group_server.py | 1297 ++++++++++++++++++++++++ 1 file changed, 1297 insertions(+) create mode 100644 synapse/storage/databases/main/group_server.py (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py new file mode 100644 index 0000000000..a98181f445 --- /dev/null +++ b/synapse/storage/databases/main/group_server.py @@ -0,0 +1,1297 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json + +# The category ID for the "default" category. We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerWorkerStore(SQLBaseStore): + def get_group(self, group_id): + return self.db_pool.simple_select_one( + table="groups", + keyvalues={"group_id": group_id}, + retcols=( + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + "join_policy", + ), + allow_none=True, + desc="get_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + return self.db_pool.simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin"), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self.db_pool.simple_select_onecol( + table="group_invites", + keyvalues={"group_id": group_id}, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ + # TODO: Pagination + + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + args = [group_id, False] + + if not include_private: + sql += " AND is_public = ?" + args += [True] + + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] + + return self.db_pool.runInteraction( + "get_rooms_in_group", _get_rooms_in_group_txn + ) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): + """Get the rooms and categories that should be included in a summary request + + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category + """ + + def _get_rooms_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, False, True)) + else: + txn.execute(sql, (group_id, False)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": db_to_json(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + + return self.db_pool.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self.db_pool.simple_select_list( + table="group_room_categories", + keyvalues={"group_id": group_id}, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + return { + row["category_id"]: { + "is_public": row["is_public"], + "profile": db_to_json(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self.db_pool.simple_select_one( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = db_to_json(category["profile"]) + + return category + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self.db_pool.simple_select_list( + table="group_roles", + keyvalues={"group_id": group_id}, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + return { + row["role_id"]: { + "is_public": row["is_public"], + "profile": db_to_json(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self.db_pool.simple_select_one( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = db_to_json(role["profile"]) + + return role + + def get_local_groups_for_room(self, room_id): + """Get all of the local group that contain a given room + Args: + room_id (str): The ID of a room + Returns: + Deferred[list[str]]: A twisted.Deferred containing a list of group ids + containing this room + """ + return self.db_pool.simple_select_onecol( + table="group_rooms", + keyvalues={"room_id": room_id}, + retcol="group_id", + desc="get_local_groups_for_room", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + + def _get_users_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": db_to_json(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + + return self.db_pool.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self.db_pool.simple_select_one_onecol( + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + + def _get_users_membership_in_group_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + + if row: + return {"membership": "invite"} + + return {} + + return self.db_pool.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self.db_pool.simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until givent time + """ + + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? + """ + txn.execute(sql, (valid_until_ms,)) + return self.db_pool.cursor_to_dict(txn) + + return self.db_pool.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. + """ + row = yield self.db_pool.simple_select_one( + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + return db_to_json(row["attestation_json"]) + + return None + + def get_joined_groups(self, user_id): + return self.db_pool.simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join"}, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": db_to_json(row[3]), + } + for row in txn + ] + + return self.db_pool.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token + ) + if not has_changed: + return defer.succeed([]) + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token)) + return [ + { + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": db_to_json(content_json), + } + for group_id, membership, gtype, content_json in txn + ] + + return self.db_pool.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn + ) + + async def get_all_groups_changes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + last_id = int(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + + if not has_changed: + return [], current_id, False + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + updates = [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ] + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn + ) + + +class GroupServerStore(GroupServerWorkerStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self.db_pool.simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues={"join_policy": join_policy}, + desc="set_group_join_policy", + ) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.db_pool.runInteraction( + "add_room_to_summary", + self._add_room_to_summary_txn, + group_id, + room_id, + category_id, + order, + is_public, + ) + + def _add_room_to_summary_txn( + self, txn, group_id, room_id, category_id, order, is_public + ): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. + """ + room_in_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute( + """ + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, + (group_id, category_id, group_id, category_id), + ) + + existing = self.db_pool.simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public"), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? + """ + txn.execute(sql, (group_id, category_id)) + (order,) = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self.db_pool.simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self.db_pool.simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self.db_pool.simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = {"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self.db_pool.simple_upsert( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self.db_pool.simple_delete( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + desc="remove_group_category", + ) + + def upsert_group_role(self, group_id, role_id, profile, is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self.db_pool.simple_upsert( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self.db_pool.simple_delete( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.db_pool.runInteraction( + "add_user_to_summary", + self._add_user_to_summary_txn, + group_id, + user_id, + role_id, + order, + is_public, + ) + + def _add_user_to_summary_txn( + self, txn, group_id, user_id, role_id, order, is_public + ): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. + """ + user_in_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self.db_pool.simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute( + """ + INSERT INTO group_summary_roles + (group_id, role_id, role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, + (group_id, role_id, group_id, role_id), + ) + + existing = self.db_pool.simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, + retcols=("user_order", "is_public"), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? + """ + txn.execute(sql, (group_id, role_id)) + (order,) = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self.db_pool.simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self.db_pool.simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self.db_pool.simple_delete( + table="group_summary_users", + keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, + desc="remove_user_from_summary", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self.db_pool.simple_insert( + table="group_invites", + values={"group_id": group_id, "user_id": user_id}, + desc="add_group_invite", + ) + + def add_user_to_group( + self, + group_id, + user_id, + is_admin=False, + is_public=True, + local_attestation=None, + remote_attestation=None, + ): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. Optional if the user and group are on the same server + """ + + def _add_user_to_group_txn(txn): + self.db_pool.simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self.db_pool.simple_delete_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + if local_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + return self.db_pool.runInteraction( + "remove_user_from_group", _remove_user_from_group_txn + ) + + def add_room_to_group(self, group_id, room_id, is_public): + return self.db_pool.simple_insert( + table="group_rooms", + values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self.db_pool.simple_update( + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="group_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + ) + + self.db_pool.simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={"group_id": group_id, "room_id": room_id}, + ) + + return self.db_pool.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self.db_pool.simple_update_one( + table="local_group_membership", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"is_publicised": publicise}, + desc="update_group_publicity", + ) + + @defer.inlineCallbacks + def register_user_group_membership( + self, + group_id, + user_id, + membership, + is_admin=False, + content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. + + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. + local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self.db_pool.simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self.db_pool.simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps( + {"membership": membership, "content": content} + ), + }, + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. + + if membership == "join": + if local_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self.db_pool.simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + else: + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + ) + + return next_id + + with self._group_updates_id_gen.get_next() as next_id: + res = yield self.db_pool.runInteraction( + "register_user_group_membership", + _register_user_group_membership_txn, + next_id, + ) + return res + + @defer.inlineCallbacks + def create_group( + self, group_id, user_id, name, avatar_url, short_description, long_description + ): + yield self.db_pool.simple_insert( + table="groups", + values={ + "group_id": group_id, + "name": name, + "avatar_url": avatar_url, + "short_description": short_description, + "long_description": long_description, + "is_public": True, + }, + desc="create_group", + ) + + @defer.inlineCallbacks + def update_group_profile(self, group_id, profile): + yield self.db_pool.simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues=profile, + desc="update_group_profile", + ) + + def update_attestation_renewal(self, group_id, user_id, attestation): + """Update an attestation that we have renewed + """ + return self.db_pool.simple_update_one( + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, + desc="update_attestation_renewal", + ) + + def update_remote_attestion(self, group_id, user_id, attestation): + """Update an attestation that a remote has renewed + """ + return self.db_pool.simple_update_one( + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + "attestation_json": json.dumps(attestation), + }, + desc="update_remote_attestion", + ) + + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self.db_pool.simple_delete( + table="group_attestations_renewals", + keyvalues={"group_id": group_id, "user_id": user_id}, + desc="remove_attestation_renewal", + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() + + def delete_group(self, group_id): + """Deletes a group fully from the database. + + Args: + group_id (str) + + Returns: + Deferred + """ + + def _delete_group_txn(txn): + tables = [ + "groups", + "group_users", + "group_invites", + "group_rooms", + "group_summary_rooms", + "group_summary_room_categories", + "group_room_categories", + "group_summary_users", + "group_summary_roles", + "group_roles", + "group_attestations_renewals", + "group_attestations_remote", + ] + + for table in tables: + self.db_pool.simple_delete_txn( + txn, table=table, keyvalues={"group_id": group_id} + ) + + return self.db_pool.runInteraction("delete_group", _delete_group_txn) -- cgit 1.5.1 From 4dd27e6d1125df83a754b5e0c2c14aaafc0ce837 Mon Sep 17 00:00:00 2001 From: David Vo Date: Fri, 7 Aug 2020 22:02:55 +1000 Subject: Reduce unnecessary whitespace in JSON. (#7372) --- changelog.d/7372.misc | 1 + synapse/http/server.py | 5 +++-- synapse/replication/tcp/commands.py | 5 +++-- synapse/rest/media/v1/preview_url_resource.py | 4 ++-- synapse/storage/databases/main/account_data.py | 7 +++---- synapse/storage/databases/main/deviceinbox.py | 9 ++++----- synapse/storage/databases/main/devices.py | 11 +++++------ synapse/storage/databases/main/e2e_room_keys.py | 11 +++++------ synapse/storage/databases/main/end_to_end_keys.py | 5 +++-- synapse/storage/databases/main/event_push_actions.py | 5 ++--- synapse/storage/databases/main/group_server.py | 17 ++++++++--------- synapse/storage/databases/main/push_rule.py | 9 ++++----- synapse/storage/databases/main/receipts.py | 9 ++++----- synapse/util/__init__.py | 4 ++++ synapse/util/frozenutils.py | 7 +++++-- 15 files changed, 56 insertions(+), 53 deletions(-) create mode 100644 changelog.d/7372.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/7372.misc b/changelog.d/7372.misc new file mode 100644 index 0000000000..67a39f0471 --- /dev/null +++ b/changelog.d/7372.misc @@ -0,0 +1 @@ +Reduce the amount of whitespace in JSON stored and sent in responses. Contributed by David Vo. diff --git a/synapse/http/server.py b/synapse/http/server.py index 94ab29974a..ffe6cfa09e 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,7 +25,7 @@ from io import BytesIO from typing import Any, Callable, Dict, Tuple, Union import jinja2 -from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json +from canonicaljson import encode_canonical_json, encode_pretty_printed_json from twisted.internet import defer from twisted.python import failure @@ -46,6 +46,7 @@ from synapse.api.errors import ( from synapse.http.site import SynapseRequest from synapse.logging.context import preserve_fn from synapse.logging.opentracing import trace_servlet +from synapse.util import json_encoder from synapse.util.caches import intern_dict logger = logging.getLogger(__name__) @@ -538,7 +539,7 @@ def respond_with_json( # canonicaljson already encodes to bytes json_bytes = encode_canonical_json(json_object) else: - json_bytes = json.dumps(json_object).encode("utf-8") + json_bytes = json_encoder.encode(json_object).encode("utf-8") return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f33801f883..d853e4447e 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,11 +18,12 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ import abc -import json import logging from typing import Tuple, Type -_json_encoder = json.JSONEncoder() +from canonicaljson import json + +from synapse.util import json_encoder as _json_encoder logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f4768a9e8b..4bb454c36f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -27,7 +27,6 @@ from typing import Dict, Optional from urllib import parse as urlparse import attr -from canonicaljson import json from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -43,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string @@ -355,7 +355,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Calculated OG for %s as %s", url, og) - jsonog = json.dumps(og) + jsonog = json_encoder.encode(og) # store OG in history-aware DB cache await self.store.store_url_cache( diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 2193d8fdc5..cf039e7f7d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,13 +18,12 @@ import abc import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -327,7 +326,7 @@ class AccountDataStore(AccountDataWorkerStore): Returns: A deferred that completes once the account_data has been added. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint @@ -373,7 +372,7 @@ class AccountDataStore(AccountDataWorkerStore): Returns: A deferred that completes once the account_data has been added. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 874ecdf8d2..76ec954f44 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,13 +16,12 @@ import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -354,7 +353,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) + edu_json = json_encoder.encode(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -432,7 +431,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Handle wildcard device_ids. sql = "SELECT device_id FROM devices WHERE user_id = ?" txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) + message_json = json_encoder.encode(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -454,7 +453,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = json.dumps(messages_by_device[device]) + message_json = json_encoder.encode(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88a7aadfc6..81e64de126 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional, Set, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import Codes, StoreError @@ -36,6 +34,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.util import json_encoder from synapse.util.caches.descriptors import ( Cache, cached, @@ -397,7 +396,7 @@ class DeviceWorkerStore(SQLBaseStore): values={ "stream_id": stream_id, "from_user_id": from_user_id, - "user_ids": json.dumps(user_ids), + "user_ids": json_encoder.encode(user_ids), }, ) @@ -1032,7 +1031,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, + values={"content": json_encoder.encode(content)}, # we don't need to lock, because we assume we are the only thread # updating this user's devices. lock=False, @@ -1088,7 +1087,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): { "user_id": user_id, "device_id": content["device_id"], - "content": json.dumps(content), + "content": json_encoder.encode(content), } for content in devices ], @@ -1209,7 +1208,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "device_id": device_id, "sent": False, "ts": now, - "opentracing_context": json.dumps(context) + "opentracing_context": json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", } diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 90152edc3c..c4aaec3993 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): @@ -50,7 +49,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), }, desc="update_e2e_room_key", ) @@ -77,7 +76,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), } ) log_kv( @@ -360,7 +359,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "user_id": user_id, "version": new_version, "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), + "auth_data": json_encoder.encode(info["auth_data"]), }, ) @@ -387,7 +386,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues = {} if info is not None and "auth_data" in info: - updatevalues["auth_data"] = json.dumps(info["auth_data"]) + updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) if version_etag is not None: updatevalues["etag"] = version_etag diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 40354b8304..6126376a6f 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -16,7 +16,7 @@ # limitations under the License. from typing import Dict, List, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection from twisted.internet import defer @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -700,7 +701,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): values={ "user_id": user_id, "keytype": key_type, - "keydata": json.dumps(key), + "keydata": json_encoder.encode(key), "stream_id": stream_id, }, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b8cefb4d5e..7c246d3e4c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -17,11 +17,10 @@ import logging from typing import List -from canonicaljson import json - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ def _serialize_action(actions, is_highlight): else: if actions == DEFAULT_NOTIF_ACTION: return "" - return json.dumps(actions) + return json_encoder.encode(actions) def _deserialize_action(actions, is_highlight): diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a98181f445..75ea6d4b2f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -16,12 +16,11 @@ from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -752,7 +751,7 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True @@ -783,7 +782,7 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True @@ -1007,7 +1006,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) @@ -1131,7 +1130,7 @@ class GroupServerStore(GroupServerWorkerStore): "is_admin": is_admin, "membership": membership, "is_publicised": is_publicised, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, ) @@ -1143,7 +1142,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "type": "membership", - "content": json.dumps( + "content": json_encoder.encode( {"membership": membership, "content": content} ), }, @@ -1171,7 +1170,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) else: @@ -1240,7 +1239,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), + "attestation_json": json_encoder.encode(attestation), }, desc="update_remote_attestion", ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 97cc12931d..264521635f 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from canonicaljson import json - from twisted.internet import defer from synapse.push.baserules import list_with_base_rules @@ -33,6 +31,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -411,8 +410,8 @@ class PushRuleStore(PushRulesWorkerStore): before=None, after=None, ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) + conditions_json = json_encoder.encode(conditions) + actions_json = json_encoder.encode(actions) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: @@ -681,7 +680,7 @@ class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) + actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): if is_default_rule: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6255977c92..1920a8a152 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -18,13 +18,12 @@ import abc import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -459,7 +458,7 @@ class ReceiptsStore(ReceiptsWorkerStore): values={ "stream_id": stream_id, "event_id": event_id, - "data": json.dumps(data), + "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -585,7 +584,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), + "event_ids": json_encoder.encode(event_ids), + "data": json_encoder.encode(data), }, ) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index c63256d3bd..b3f76428b6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -17,6 +17,7 @@ import logging import re import attr +from canonicaljson import json from twisted.internet import defer, task @@ -24,6 +25,9 @@ from synapse.logging import context logger = logging.getLogger(__name__) +# Create a custom encoder to reduce the whitespace produced by JSON encoding. +json_encoder = json.JSONEncoder(separators=(",", ":")) + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index eab78dd256..0e445e01d7 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -63,5 +63,8 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendicts without barfing -frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) +# A JSONEncoder which is capable of encoding frozendicts without barfing. +# Additionally reduce the whitespace produced by JSON encoding. +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, separators=(",", ":"), +) -- cgit 1.5.1 From a3a59bab7bb3b69dcfc5620e6f3ac51af3f0f965 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 09:28:48 -0400 Subject: Convert appservice, group server, profile and more databases to async (#8066) --- changelog.d/8066.misc | 1 + synapse/storage/databases/main/appservice.py | 34 ++++------ synapse/storage/databases/main/filtering.py | 8 +-- synapse/storage/databases/main/group_server.py | 86 ++++++++++++-------------- synapse/storage/databases/main/presence.py | 7 +-- synapse/storage/databases/main/profile.py | 21 +++---- synapse/storage/databases/main/relations.py | 19 +++--- synapse/storage/databases/main/transactions.py | 7 +-- tests/storage/test_appservice.py | 24 +++---- 9 files changed, 91 insertions(+), 116 deletions(-) create mode 100644 changelog.d/8066.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8066.misc b/changelog.d/8066.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8066.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 055a3962dc..5cf1a88399 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -18,8 +18,6 @@ import re from canonicaljson import json -from twisted.internet import defer - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): + async def get_appservices_by_state(self, state): """Get a list of application services based on their state. Args: state(ApplicationServiceState): The state to filter on. Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. + A list of ApplicationServices, which may be empty. """ - results = yield self.db_pool.simple_select_list( + results = await self.db_pool.simple_select_list( "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - @defer.inlineCallbacks - def get_appservice_state(self, service): + async def get_appservice_state(self, service): """Get the application service state. Args: service(ApplicationService): The service whose state to set. Returns: - A Deferred which resolves to ApplicationServiceState. + An ApplicationServiceState. """ - result = yield self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one( "application_services_state", {"as_id": service.id}, ["state"], @@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): + async def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. Args: service(ApplicationService): The app service to get the oldest txn. Returns: - A Deferred which resolves to an AppServiceTransaction or - None. + An AppServiceTransaction or None. """ def _get_oldest_unsent_txn(txn): @@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.db_pool.runInteraction( + entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = db_to_json(entry["event_ids"]) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice(self, current_id, limit): """Get all new evnets""" def get_new_events_for_appservice_txn(txn): @@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cae6bda80e..45a1760170 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): + @cached(num_args=2) + async def get_user_filter(self, user_localpart, filter_id): # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.db_pool.simple_select_one_onecol( + def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 75ea6d4b2f..380db3a3f3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple - -from twisted.internet import defer +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the @@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_categories(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db_pool.simple_select_one( + async def get_group_category(self, group_id, category_id): + category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore): return category - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_roles(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db_pool.simple_select_one( + async def get_group_role(self, group_id, role_id): + role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - def get_groups_changes_for_user(self, user_id, from_token, to_token): + async def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( user_id, from_token ) if not has_changed: - return defer.succeed([]) + return [] def _get_groups_changes_for_user_txn(txn): sql = """ @@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_publicity", ) - @defer.inlineCallbacks - def register_user_group_membership( + async def register_user_group_membership( self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + group_id: str, + user_id: str, + membership: str, + is_admin: bool = False, + content: JsonDict = {}, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, + is_publicised: bool = False, + ) -> int: """Registers that a local user is a member of a (local or remote) group. Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter + group_id: The group the member is being added to. + user_id: THe user ID to add to the group. + membership: The type of group membership. + is_admin: Whether the user should be added as a group admin. + content: Content of the membership, e.g. includes the inviter if the user has been invited. - local_attestation (dict): If remote group then store the fact that we + local_attestation: If remote group then store the fact that we have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote + remote_attestation: If remote group then store the remote attestation from the group, else None. + is_publicised: Whether this should be publicised. """ def _register_user_group_membership_txn(txn, next_id): @@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db_pool.runInteraction( + res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, ) return res - @defer.inlineCallbacks - def create_group( + async def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( table="groups", values={ "group_id": group_id, @@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db_pool.simple_update_one( + async def update_group_profile(self, group_id, profile): + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 99e66dc6e9..59ba12820a 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,6 @@ from typing import List, Tuple -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList @@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): + async def update_presence(self, presence_states): stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) with stream_ordering_manager as stream_orderings: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 4a4f2cb385..b8261357d4 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart): try: - profile = yield self.db_pool.simple_select_one( + profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore): desc="update_remote_profile_cache", ) - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id): """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + subscribed = await self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore): _get_remote_profile_cache_entries_that_expire_txn, ) - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b81f1449b7..a9ceffc20e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Optional import attr from synapse.api.constants import RelationTypes +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( @@ -25,7 +27,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): + @cached() + async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. Correctly handles checking whether edits were allowed to happen. Args: - event_id (str): The original event ID + event_id: The original event ID Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. + The most recent edit, if any. """ # We only allow edits for `m.room.message` events that have the same sender @@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.db_pool.runInteraction( + edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) if not edit_id: - return + return None - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event + return await self.get_event(edit_id, allow_none=True) def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): """Check if a user has already annotated an event with the same key diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 8804c0e4ac..52668dbdf9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -18,8 +18,6 @@ from collections import namedtuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool @@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore): desc="set_received_txn_response", ) - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): + async def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. Args: @@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1b516b7976..98b74890d5 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservice_state_none(self): service = Mock(id="999") - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks def test_get_appservice_state_up(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks @@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks def test_get_appservices_by_state_none(self): - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @@ -339,7 +339,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks @@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) @@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) @@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) self.assertEquals( -- cgit 1.5.1 From 76c43f086a3c5feeab1fd0f916086930d44fb1cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 06:39:55 -0400 Subject: Do not assume calls to runInteraction return Deferreds. (#8133) --- changelog.d/8133.misc | 1 + synapse/crypto/keyring.py | 7 +++--- synapse/module_api/__init__.py | 10 +++++--- synapse/spam_checker_api/__init__.py | 6 +++-- synapse/storage/databases/main/group_server.py | 7 +++--- synapse/storage/databases/main/keys.py | 28 ++++++++++++---------- .../storage/databases/main/user_erasure_store.py | 13 +++++----- 7 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8133.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8133.misc b/changelog.d/8133.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8133.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 28ef7cfdb9..81c4b430b2 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - return await yieldable_gather_results( - get_key, keys_to_fetch.items() - ).addCallback(lambda _: results) + await yieldable_gather_results(get_key, keys_to_fetch.items()) + return results async def get_server_verify_key_v2_direct(self, server_name, key_ids): """ @@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): key_ids (iterable[str]): Returns: - Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result + dict[str, FetchKeyResult]: map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c2fb757d9a..ae0e359a77 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -167,8 +167,10 @@ class ModuleApi(object): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._store.record_user_external_id( - auth_provider_id, remote_user_id, registered_user_id + return defer.ensureDeferred( + self._store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) ) def generate_short_term_login_token( @@ -223,7 +225,9 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.db_pool.runInteraction(desc, func, *args, **kwargs) + return defer.ensureDeferred( + self._store.db_pool.runInteraction(desc, func, *args, **kwargs) + ) def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 4d9b13ac04..7f63f1bfa0 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -48,8 +48,10 @@ class SpamCheckerApi(object): twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: The filtered state events in the room. """ - state_ids = yield self._store.get_filtered_current_state_ids( - room_id=room_id, state_filter=StateFilter.from_types(types) + state_ids = yield defer.ensureDeferred( + self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) ) state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 380db3a3f3..0e3b8739c6 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore): "get_users_for_summary_by_role", _get_users_for_summary_txn ) - def is_user_in_group(self, user_id, group_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_in_group(self, user_id: str, group_id: str) -> bool: + result = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) + ) + return bool(result) def is_user_admin_in_group(self, group_id, user_id): return self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 384e9c5eb0..fadcad51e7 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,6 +16,7 @@ import itertools import logging +from typing import Iterable, Tuple from signedjson.key import decode_verify_key_bytes @@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore): return self.db_pool.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + async def store_server_verify_keys( + self, + from_server: str, + ts_added_ms: int, + verify_keys: Iterable[Tuple[str, str, FetchKeyResult]], + ) -> None: """Stores NACL verification keys for remote servers. Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + from_server: Where the verification keys were looked up + ts_added_ms: The time to record that the key was added + verify_keys: keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ @@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "store_server_verify_keys", self.db_pool.simple_upsert_many_txn, table="server_signature_keys", @@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, - ).addCallback(_invalidate) + ) + + invalidate = self._get_server_verify_key.invalidate + for i in invalidations: + invalidate((i,)) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index da23fe7355..e3547e53b3 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -13,30 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList class UserErasureWorkerStore(SQLBaseStore): @cached() - def is_user_erased(self, user_id): + async def is_user_erased(self, user_id: str) -> bool: """ Check if the given user id has requested erasure Args: - user_id (str): full user id to check + user_id: full user id to check Returns: - Deferred[bool]: True if the user has requested erasure + True if the user has requested erasure """ - return self.db_pool.simple_select_onecol( + result = await self.db_pool.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", desc="is_user_erased", - ).addCallback(operator.truth) + ) + return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") async def are_users_erased(self, user_ids): -- cgit 1.5.1 From 2231dffee6788836c86e868dd29574970b13dd18 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 25 Aug 2020 15:10:08 +0100 Subject: Make StreamIdGen `get_next` and `get_next_mult` async (#8161) This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator` will have the same interface, allowing them to be used interchangeably. --- changelog.d/8161.misc | 1 + synapse/storage/databases/main/account_data.py | 4 +-- synapse/storage/databases/main/deviceinbox.py | 4 +-- synapse/storage/databases/main/devices.py | 8 +++-- synapse/storage/databases/main/end_to_end_keys.py | 43 ++++++++++++----------- synapse/storage/databases/main/events.py | 4 +-- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/presence.py | 2 +- synapse/storage/databases/main/push_rule.py | 8 ++--- synapse/storage/databases/main/pusher.py | 4 +-- synapse/storage/databases/main/receipts.py | 3 +- synapse/storage/databases/main/room.py | 6 ++-- synapse/storage/databases/main/tags.py | 4 +-- synapse/storage/util/id_generators.py | 10 +++--- 14 files changed, 54 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8161.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8161.misc b/changelog.d/8161.misc new file mode 100644 index 0000000000..89ff274de3 --- /dev/null +++ b/changelog.d/8161.misc @@ -0,0 +1 @@ +Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 82aac2bbf3..04042a2c98 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1f6e995c4f..bb85637a95 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9a786e2929..03b45dbc4d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: + with await self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, @@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f93e0d320d..385868bdab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db_pool.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json_encoder.encode(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db_pool.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) + + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b90e6de2d5..6313b41eef 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -153,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 0e3b8739c6..a488e0924b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: + with await self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4e3ec02d14..c9f655dfb7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a585e54812..2fb5b02d7d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 1126fd0751..c388468273 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 19ad1c056f..6821476ee0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: + with await self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d3ac47261..b3772be2b2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index ade7abc927..0c34bbf21a 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0bf772d4d1..ddb5c8c60c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -80,7 +80,7 @@ class StreamIdGenerator(object): upwards, -1 to grow downwards. Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -95,10 +95,10 @@ class StreamIdGenerator(object): ) self._unfinished_ids = deque() # type: Deque[int] - def get_next(self): + async def get_next(self): """ Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -117,10 +117,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, n): + async def get_next_mult(self, n): """ Usage: - with stream_id_gen.get_next(n) as stream_ids: + with await stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: -- cgit 1.5.1 From 4c6c56dc58aba7af92f531655c2355d8f25e529c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 07:19:32 -0400 Subject: Convert simple_select_one and simple_select_one_onecol to async (#8162) --- changelog.d/8162.misc | 1 + synapse/storage/database.py | 36 +++++++++++--- synapse/storage/databases/main/devices.py | 14 +++--- synapse/storage/databases/main/directory.py | 4 +- synapse/storage/databases/main/e2e_room_keys.py | 8 ++-- synapse/storage/databases/main/events_worker.py | 10 ++-- synapse/storage/databases/main/group_server.py | 18 ++++--- synapse/storage/databases/main/media_repository.py | 13 +++-- .../storage/databases/main/monthly_active_users.py | 15 +++--- synapse/storage/databases/main/profile.py | 17 ++++--- synapse/storage/databases/main/receipts.py | 6 ++- synapse/storage/databases/main/registration.py | 10 ++-- synapse/storage/databases/main/rejections.py | 5 +- synapse/storage/databases/main/room.py | 10 ++-- synapse/storage/databases/main/state.py | 4 +- synapse/storage/databases/main/stats.py | 10 ++-- synapse/storage/databases/main/user_directory.py | 9 ++-- tests/handlers/test_profile.py | 56 +++++++++++++++++----- tests/handlers/test_typing.py | 4 +- tests/module_api/test_api.py | 2 +- tests/storage/test_base.py | 28 ++++++----- tests/storage/test_devices.py | 8 ++-- tests/storage/test_profile.py | 23 +++++++-- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 20 ++++++-- 25 files changed, 220 insertions(+), 113 deletions(-) create mode 100644 changelog.d/8162.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8162.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bc327e344e..181c3ec249 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,9 +29,11 @@ from typing import ( Tuple, TypeVar, Union, + overload, ) from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi from twisted.internet import defer @@ -1020,14 +1022,36 @@ class DatabasePool(object): return txn.execute_batch(sql, args) - def simple_select_one( + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( self, table: str, keyvalues: Dict[str, Any], retcols: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> defer.Deferred: + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1038,18 +1062,18 @@ class DatabasePool(object): allow_none: If true, return None instead of failing if the SELECT statement returns no rows """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + async def simple_select_one_onecol( self, table: str, keyvalues: Dict[str, Any], retcol: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> defer.Deferred: + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1061,7 +1085,7 @@ class DatabasePool(object): statement returns no rows desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 03b45dbc4d..a811a39eb5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id: str, device_id: str): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id: str): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 037e02603c..301d5d845a 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db_pool.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 2eeb9f97dc..46c3e33cc6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e1241a724b..d59d73938a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a488e0924b..c39864f59f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = "" class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db_pool.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore): ) return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 80fc1cd009..4ae255ebd8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db_pool.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e71cdd2cb4..fe30552c08 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + Arguments: + user_id: user to add/update + + Return: + Timestamp since last seen, None if never seen """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8261357d4..b8233c4848 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - async def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", @@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - def get_profile_displayname(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", desc="get_profile_displayname", ) - def get_profile_avatar_url(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_from_remote_profile_cache(self, user_id): - return self.db_pool.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6821476ee0..cea5ac9a68 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db_pool.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 321a51cc6a..eced53d470 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index cf9ba51205..1e361aaa9a 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index b3772be2b2..97ecdb16e4 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore): return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db_pool.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 991233a9bc..458f169617 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 802c9019b9..9fe97af56a 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore): return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af21fe457a..20cbcd851c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,6 +15,7 @@ import logging import re +from typing import Any, Dict, Optional from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - def get_user_in_directory(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - def get_user_directory_stream_pos(self): - return self.db_pool.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index b609b30d4a..60ebc95f3e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) displayname = yield defer.ensureDeferred( self.handler.get_displayname(self.frank) @@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) @defer.inlineCallbacks @@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) # Setting displayname a second time is forbidden @@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield self.store.set_profile_displayname("caroline", "Caroline") + yield defer.ensureDeferred( + self.store.set_profile_displayname("caroline", "Caroline") + ) response = yield defer.ensureDeferred( self.query_handlers["profile"]( @@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/pic.gif", ) @@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) @@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index e01de158e5..834b4a0af6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.return_value = ( + self.datastore.get_user_directory_stream_pos.side_effect = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam - defer.succeed(1) + lambda: make_awaitable(1) ) self.datastore.get_current_state_deltas.return_value = (0, None) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 807cd65dd6..04de0b9dbe 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the new user exists with all provided attributes self.assertEqual(user_id, "@bob:test") self.assertTrue(access_token) - self.assertTrue(self.store.get_user_by_id(user_id)) + self.assertTrue(self.get_success(self.store.get_user_by_id(user_id))) # Check that the email was assigned emails = self.get_success(self.store.user_get_threepids(user_id)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 13bcac743a..bf22540d99 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db_pool.simple_select_one_onecol( - table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + value = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one_onecol( + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + ) ) self.assertEquals("Value", value) @@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"], + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"], + ) ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) @@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True, + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True, + ) ) self.assertFalse(ret) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 87ed8f8cd1..34ae8c9da7 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # check it worked - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) @defer.inlineCallbacks diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9d5b8aa47d..3fd0a38cf5 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase): def test_displayname(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ), ) @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.u_frank.localpart, "http://my.site/here" + ) ) self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ), ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 58f827d8d3..70c55cd650 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "user_type": None, "deactivated": 0, }, - (yield self.store.get_user_by_id(self.user_id)), + (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index d07b985a8e..bc8400f240 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield self.store.get_room(self.room.to_string())), + (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), ) @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone((yield self.store.get_room("!uknown:test")),) + self.assertIsNone( + (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) + ) @defer.inlineCallbacks def test_get_room_with_stats(self): @@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - (yield self.store.get_room_with_stats(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): - self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats("!uknown:test") + ) + ), + ) class RoomEventsStoreTestCase(unittest.TestCase): -- cgit 1.5.1 From 4a739c73b404284253a548f60197e70c6c385645 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 07:08:38 -0400 Subject: Convert simple_update* and simple_select* to async (#8173) --- changelog.d/8173.misc | 1 + synapse/handlers/room.py | 6 +-- synapse/storage/database.py | 29 ++++++------ synapse/storage/databases/main/__init__.py | 8 ++-- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/e2e_room_keys.py | 26 ++++++---- synapse/storage/databases/main/event_federation.py | 4 +- synapse/storage/databases/main/group_server.py | 55 +++++++++++++--------- synapse/storage/databases/main/media_repository.py | 16 ++++--- synapse/storage/databases/main/profile.py | 18 ++++--- synapse/storage/databases/main/receipts.py | 8 ++-- synapse/storage/databases/main/registration.py | 22 +++++---- synapse/storage/databases/main/room.py | 4 +- synapse/storage/databases/main/user_directory.py | 4 +- tests/handlers/test_stats.py | 4 +- tests/storage/test_base.py | 26 ++++++---- tests/storage/test_directory.py | 6 ++- tests/storage/test_main.py | 4 +- tests/test_federation.py | 50 +++++++++----------- 19 files changed, 164 insertions(+), 133 deletions(-) create mode 100644 changelog.d/8173.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8173.misc b/changelog.d/8173.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8173.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e4788ef86b..236a37f777 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,7 +51,7 @@ from synapse.types import ( create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer, maybe_awaitable +from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -1329,9 +1329,7 @@ class RoomShutdownHandler(object): ratelimit=False, ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) - ) + aliases_for_room = await self.store.get_aliases_for_room(room_id) await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 181c3ec249..38010af600 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1132,13 +1132,13 @@ class DatabasePool(object): return [r[0] for r in txn] - def simple_select_onecol( + async def simple_select_onecol( self, table: str, keyvalues: Optional[Dict[str, Any]], retcol: str, desc: str = "simple_select_onecol", - ) -> defer.Deferred: + ) -> List[Any]: """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -1148,19 +1148,19 @@ class DatabasePool(object): retcol: column whos value we wish to retrieve. Returns: - Deferred: Results in a list + Results in a list """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def simple_select_list( + async def simple_select_list( self, table: str, keyvalues: Optional[Dict[str, Any]], retcols: Iterable[str], desc: str = "simple_select_list", - ) -> defer.Deferred: + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1170,10 +1170,11 @@ class DatabasePool(object): column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_list_txn, table, keyvalues, retcols ) @@ -1299,14 +1300,14 @@ class DatabasePool(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def simple_update( + async def simple_update( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str, - ) -> defer.Deferred: - return self.runInteraction( + ) -> int: + return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1332,13 +1333,13 @@ class DatabasePool(object): return txn.rowcount - def simple_update_one( + async def simple_update_one( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str = "simple_update_one", - ) -> defer.Deferred: + ) -> None: """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1347,7 +1348,7 @@ class DatabasePool(object): keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update """ - return self.runInteraction( + await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0934ae276c..8b9b6eb472 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,6 +18,7 @@ import calendar import logging import time +from typing import Any, Dict, List from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -476,14 +477,13 @@ class DataStore( "generate_user_daily_visits", _generate_user_daily_visits ) - def get_users(self): + async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. - Args: Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries representing users. """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 301d5d845a..405b5eafa5 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import namedtuple -from typing import Iterable, Optional +from typing import Iterable, List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore): ) @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_aliases_for_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 46c3e33cc6..82f9d870fd 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version( - self, user_id, version, info=None, version_etag=None - ): + async def update_e2e_room_keys_version( + self, + user_id: str, + version: str, + info: Optional[dict] = None, + version_etag: Optional[int] = None, + ) -> None: """Update a given backup version Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info (dict): the new backup version info to store. If None, then - the backup version info is not updated - version_etag (Optional[int]): etag of the keys in the backup. If - None, then the etag is not updated + user_id: the user whose backup version we're updating + version: the version ID of the backup version we're updating + info: the new backup version info to store. If None, then the backup + version info is not updated. + version_etag: etag of the keys in the backup. If None, then the etag + is not updated. """ updatevalues = {} @@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self.db_pool.simple_update( + await self.db_pool.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index e6a97b018c..6e5761c7b7 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -368,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index c39864f59f..e3ead71853 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group", ) - def get_users_in_group(self, group_id, include_private=False): + async def get_users_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Any]]: # TODO: Pagination keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), desc="get_users_in_group", ) - def get_invited_users_in_group(self, group_id): + async def get_invited_users_in_group(self, group_id: str) -> List[str]: # TODO: Pagination - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore): return role - def get_local_groups_for_room(self, room_id): + async def get_local_groups_for_room(self, room_id: str) -> List[str]: """Get all of the local group that contain a given room Args: - room_id (str): The ID of a room + room_id: The ID of a room Returns: - Deferred[list[str]]: A twisted.Deferred containing a list of group ids - containing this room + A list of group ids containing this room """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -422,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore): "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) - def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: """Get all groups a user is publicising """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -466,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore): return None - def get_joined_groups(self, user_id): - return self.db_pool.simple_select_onecol( + async def get_joined_groups(self, user_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -585,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore): class GroupServerStore(GroupServerWorkerStore): - def set_group_join_policy(self, group_id, join_policy): + async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: """Set the join policy of a group. join_policy can be one of: * "invite" * "open" """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -1050,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore): desc="add_room_to_group", ) - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db_pool.simple_update( + async def update_room_in_group_visibility( + self, group_id: str, room_id: str, is_public: bool + ) -> int: + return await self.db_pool.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -1076,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore): "remove_room_from_group", _remove_room_from_group_txn ) - def update_group_publicity(self, group_id, user_id, publicise): + async def update_group_publicity( + self, group_id: str, user_id: str, publicise: bool + ) -> None: """Update whether the user is publicising their membership of the group """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -1218,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_profile", ) - def update_attestation_renewal(self, group_id, user_id, attestation): + async def update_attestation_renewal( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that we have renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, desc="update_attestation_renewal", ) - def update_remote_attestion(self, group_id, user_id, attestation): + async def update_remote_attestion( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that a remote has renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4ae255ebd8..fc223f5a2a 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -84,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) - def mark_local_media_as_safe(self, media_id: str): + async def mark_local_media_as_safe(self, media_id: str) -> None: """Mark a local media as safe from quarantining.""" - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_media_repository", keyvalues={"media_id": media_id}, updatevalues={"safe_from_quarantine": True}, @@ -158,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_url_cache", ) - def get_local_media_thumbnails(self, media_id): - return self.db_pool.simple_select_list( + async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -271,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "update_cached_last_access_time", update_cache_txn ) - def get_remote_media_thumbnails(self, origin, media_id): - return self.db_pool.simple_select_list( + async def get_remote_media_thumbnails( + self, origin: str, media_id: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8233c4848..858fd92420 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -71,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore): table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db_pool.simple_update_one( + async def set_profile_displayname( + self, user_localpart: str, new_displayname: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, desc="set_profile_displayname", ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db_pool.simple_update_one( + async def set_profile_avatar_url( + self, user_localpart: str, new_avatar_url: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -106,8 +110,10 @@ class ProfileStore(ProfileWorkerStore): desc="add_remote_profile_cache", ) - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db_pool.simple_update( + async def update_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> int: + return await self.db_pool.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index cea5ac9a68..436f22ad2d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -16,7 +16,7 @@ import abc import logging -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer @@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {r["user_id"] for r in receipts} @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self.db_pool.simple_select_list( + async def get_receipts_for_room( + self, room_id: str, receipt_type: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eced53d470..48bda66f3e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) - def user_get_bound_threepids(self, user_id): + async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. Args: - user_id (str): The ID of the user to retrieve threepids for + user_id: The ID of the user to retrieve threepids for Returns: - Deferred[list[dict]]: List of dictionaries containing the following: + List of dictionaries containing the following keys: medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore): desc="remove_user_bound_threepid", ) - def get_id_servers_user_bound(self, user_id, medium, address): + async def get_id_servers_user_bound( + self, user_id: str, medium: str, address: str + ) -> List[str]: """Get the list of identity servers that the server proxied bind requests to for given user and threepid Args: - user_id (str) - medium (str) - address (str) + user_id: The user to query for identity servers. + medium: The medium to query for identity servers. + address: The address to query for identity servers. Returns: - Deferred[list[str]]: Resolves to a list of identity servers + A list of identity servers """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 97ecdb16e4..66d7135413 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -125,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore): "get_room_with_stats", get_room_with_stats_txn, room_id ) - def get_public_room_ids(self): - return self.db_pool.simple_select_onecol( + async def get_public_room_ids(self) -> List[str]: + return await self.db_pool.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 20cbcd851c..a9f2e93614 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -537,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - def update_user_directory_stream_pos(self, stream_id): - return self.db_pool.simple_update_one( + async def update_user_directory_stream_pos(self, stream_id: str) -> None: + await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 88b05c23a0..a609f148c0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def get_all_room_state(self): - return self.store.db_pool.simple_select_list( + async def get_all_room_state(self): + return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index bf22540d99..64abe8cc49 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -148,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db_pool.simple_select_list( - table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_list( + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ) ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) @@ -161,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"}, + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"}, + ) ) self.mock_txn.execute.assert_called_with( @@ -176,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index daac947cb2..da93ca3980 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_aliases_for_room(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index fbf8af940a..954338a592 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase): def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname) + ) users, total = yield self.store.get_users_paginate( 0, 10, name="bc", guests=False diff --git a/tests/test_federation.py b/tests/test_federation.py index 4a4548433f..27a7fc9ed7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -15,8 +15,9 @@ from mock import Mock -from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed +from twisted.internet.defer import succeed +from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID @@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room_deferred = ensureDeferred( + self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - ) - self.reactor.advance(0.1) - self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + )[0]["room_id"] self.store = self.homeserver.get_datastore() # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + most_recent = self.get_success( + self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Send the join, it should return None (which is not an error) - d = ensureDeferred( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + self.assertEqual( + self.get_success( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + ), + None, ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) # Make sure we actually joined the room self.assertEqual( - self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - )[0], + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], "$join:test.serv", ) @@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) + most_recent = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) )[0] # Now lie about an event @@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) with LoggingContext(request="lying_event"): - d = ensureDeferred( + failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True - ) + ), + FederationError, ) - # Step the reactor, so the database fetches come back - self.reactor.advance(1) - # on_receive_pdu should throw an error - failure = self.failureResultOf(d) self.assertEqual( failure.value.args[0], ( @@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem[0], "$join:test.serv") def test_retry_device_list_resync(self): """Tests that device lists are marked as stale if they couldn't be synced, and -- cgit 1.5.1 From 9b7ac03af3e7ceae7d1933db566ee407cfdef72d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 13:38:41 -0400 Subject: Convert calls of async database methods to async (#8166) --- changelog.d/8166.misc | 1 + synapse/federation/persistence.py | 16 +++++++----- synapse/federation/units.py | 4 +-- synapse/storage/databases/main/appservice.py | 6 ++--- synapse/storage/databases/main/devices.py | 4 +-- synapse/storage/databases/main/group_server.py | 30 ++++++++++++++++------ synapse/storage/databases/main/keys.py | 26 +++++++++++-------- synapse/storage/databases/main/media_repository.py | 22 ++++++++-------- synapse/storage/databases/main/openid.py | 6 +++-- synapse/storage/databases/main/profile.py | 10 +++++--- synapse/storage/databases/main/registration.py | 29 ++++++++++----------- synapse/storage/databases/main/room.py | 16 ++++++++---- synapse/storage/databases/main/stats.py | 10 ++++---- synapse/storage/databases/main/transactions.py | 18 +++++++------ 14 files changed, 114 insertions(+), 84 deletions(-) create mode 100644 changelog.d/8166.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8166.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index d68b4bd670..769cd5de28 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module. import logging +from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -49,15 +51,15 @@ class TransactionActions(object): return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function - def set_response(self, origin, transaction, code, response): + async def set_response( + self, origin: str, transaction: Transaction, code: int, response: JsonDict + ) -> None: """ Persist how we responded to a transaction. - - Returns: - Deferred """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.set_received_txn_response( - transaction.transaction_id, origin, code, response + await self.store.set_received_txn_response( + transaction_id, origin, code, response ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6b32e0dcbf..64d98fc8f6 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super(Transaction, self).__init__( - transaction_id=transaction_id, pdus=pdus, **kwargs - ) + super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 77723f7d4d..92f56f1602 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - def set_appservice_state(self, service, state): + async def set_appservice_state(self, service, state) -> None: """Set the application service state. Args: service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. - Returns: - An Awaitable which resolves when the state was set successfully. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a811a39eb5..ecd3f3b310 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore): return {row["user_id"] for row in rows} - def mark_remote_user_device_cache_as_stale(self, user_id: str): + async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index e3ead71853..8acf254bf3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_room_from_summary", ) - def upsert_group_category(self, group_id, category_id, profile, is_public): + async def upsert_group_category( + self, + group_id: str, + category_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/update room category for group """ insertion_values = {} @@ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_group_category", ) - def upsert_group_role(self, group_id, role_id, profile, is_public): + async def upsert_group_role( + self, + group_id: str, + role_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/remove user role """ insertion_values = {} @@ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_user_from_summary", ) - def add_group_invite(self, group_id, user_id): + async def add_group_invite(self, group_id: str, user_id: str) -> None: """Record that the group server has invited a user """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore): "remove_user_from_group", _remove_user_from_group_txn ) - def add_room_to_group(self, group_id, room_id, is_public): - return self.db_pool.simple_insert( + async def add_room_to_group( + self, group_id: str, room_id: str, is_public: bool + ) -> None: + await self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index fadcad51e7..1c0a049c55 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -140,22 +140,28 @@ class KeyStore(SQLBaseStore): for i in invalidations: invalidate((i,)) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 8361dd63d9..3919ecad69 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) - def store_local_media( + async def store_local_media( self, media_id, media_type, @@ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, user_id, url_cache=None, - ): - return self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache( + async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media_thumbnails", ) - def store_local_thumbnail( + async def store_local_thumbnail( self, media_id, thumbnail_width, @@ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_cached_remote_media", ) - def store_cached_remote_media( + async def store_cached_remote_media( self, origin, media_id, @@ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail( + async def store_remote_media_thumbnail( self, origin, media_id, @@ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index dcd1ff911a..4db8949da7 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -2,8 +2,10 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db_pool.simple_insert( + async def insert_open_id_token( + self, token: str, ts_valid_until_ms: int, user_id: str + ) -> None: + await self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 858fd92420..301875a672 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -66,8 +66,8 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - def create_profile(self, user_localpart): - return self.db_pool.simple_insert( + async def create_profile(self, user_localpart: str) -> None: + await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) @@ -93,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: """Ensure we are caching the remote user's profiles. This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 48bda66f3e..28f7ae0430 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Awaitable, Dict, List, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -549,23 +549,22 @@ class RegistrationWorkerStore(SQLBaseStore): desc="user_delete_threepids", ) - def add_user_bound_threepid(self, user_id, medium, address, id_server): + async def add_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ): """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Awaitable + user_id + medium + address + id_server """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1083,9 +1082,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - def record_user_external_id( + async def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Awaitable: + ) -> None: """Record a mapping from an external user id to a mxid Args: @@ -1093,7 +1092,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1237,12 +1236,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return res if res else False - def add_user_pending_deactivation(self, user_id): + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 66d7135413..a92641c339 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -1296,11 +1296,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return self.db_pool.runInteraction("get_rooms", f) - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): + async def add_event_report( + self, + room_id: str, + event_id: str, + user_id: str, + reason: str, + content: JsonDict, + received_ts: int, + ) -> None: next_id = self._event_reports_id_gen.get_next() - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9fe97af56a..7af2608ca4 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from itertools import chain -from typing import Tuple +from typing import Any, Dict, Tuple from twisted.internet.defer import DeferredLock @@ -222,11 +222,11 @@ class StatsStore(StateDeltasStore): desc="stats_incremental_position", ) - def update_room_state(self, room_id, fields): + async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: """ Args: - room_id (str) - fields (dict[str:Any]) + room_id + fields """ # For whatever reason some of the fields may contain null bytes, which @@ -244,7 +244,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 52668dbdf9..2efcc0dc66 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore): else: return None - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and + async def set_received_txn_response( + self, transaction_id: str, origin: str, code: int, response_dict: JsonDict + ) -> None: + """Persist the response we returned for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) + transaction_id: The incoming transaction ID. + origin: The origin server. + code: The response code. + response_dict: The response, to be encoded into JSON. """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, -- cgit 1.5.1 From b71d4a094c4370d0229fbd4545b85c049364ecf3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 14:16:41 -0400 Subject: Convert simple_delete to async/await. (#8191) --- changelog.d/8191.misc | 1 + synapse/storage/database.py | 63 ++++++++++++++++++++++---- synapse/storage/databases/main/group_server.py | 28 +++++++----- synapse/storage/databases/main/registration.py | 29 ++++++------ tests/storage/test_event_push_actions.py | 6 ++- 5 files changed, 90 insertions(+), 37 deletions(-) create mode 100644 changelog.d/8191.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8191.misc b/changelog.d/8191.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8191.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ba4c0c9af6..7ab370efef 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -614,6 +614,7 @@ class DatabasePool(object): """Runs a single query for a result set. Args: + desc: description of the transaction, for logging and metrics decoder - The function which can resolve the cursor results to something meaningful. query - The query string to execute @@ -649,7 +650,7 @@ class DatabasePool(object): or_ignore: bool stating whether an exception should be raised when a conflicting row already exists. If True, False will be returned by the function instead - desc: string giving a description of the transaction + desc: description of the transaction, for logging and metrics Returns: Whether the row was inserted or not. Only useful when `or_ignore` is True @@ -686,7 +687,7 @@ class DatabasePool(object): Args: table: string giving the table name values: dict of new column names and values for them - desc: string giving a description of the transaction + desc: description of the transaction, for logging and metrics """ await self.runInteraction(desc, self.simple_insert_many_txn, table, values) @@ -700,7 +701,6 @@ class DatabasePool(object): txn: The transaction to use. table: string giving the table name values: dict of new column names and values for them - desc: string giving a description of the transaction """ if not values: return @@ -755,6 +755,7 @@ class DatabasePool(object): keyvalues: The unique key columns and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + desc: description of the transaction, for logging and metrics lock: True to lock the table when doing the upsert. Returns: Native upserts always return None. Emulated upserts return True if a @@ -1081,6 +1082,7 @@ class DatabasePool(object): retcols: list of strings giving the names of the columns to return allow_none: If true, return None instead of failing if the SELECT statement returns no rows + desc: description of the transaction, for logging and metrics """ return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none @@ -1166,6 +1168,7 @@ class DatabasePool(object): table: table name keyvalues: column names and values to select the rows with retcol: column whos value we wish to retrieve. + desc: description of the transaction, for logging and metrics Returns: Results in a list @@ -1190,6 +1193,7 @@ class DatabasePool(object): column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + desc: description of the transaction, for logging and metrics Returns: A list of dictionaries. @@ -1243,14 +1247,16 @@ class DatabasePool(object): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: table: string giving the table name column: column name to test for inclusion against `iterable` iterable: list - keyvalues: dict of column names and values to select the rows with retcols: list of strings giving the names of the columns to return + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + batch_size: the number of rows for each select query """ results = [] # type: List[Dict[str, Any]] @@ -1291,7 +1297,7 @@ class DatabasePool(object): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: txn: Transaction object @@ -1367,6 +1373,7 @@ class DatabasePool(object): table: string giving the table name keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update + desc: description of the transaction, for logging and metrics """ await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues @@ -1426,6 +1433,7 @@ class DatabasePool(object): Args: table: string giving the table name keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics """ await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @@ -1451,13 +1459,38 @@ class DatabasePool(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str): - return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + async def simple_delete( + self, table: str, keyvalues: Dict[str, Any], desc: str + ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics + + Returns: + The number of deleted rows. + """ + return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod def simple_delete_txn( txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + + Returns: + The number of deleted rows. + """ sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1474,6 +1507,20 @@ class DatabasePool(object): keyvalues: Dict[str, Any], desc: str, ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + + Returns: + Number rows deleted + """ return await self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 8acf254bf3..6c60171888 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -728,11 +728,13 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - def remove_room_from_summary(self, group_id, room_id, category_id): + async def remove_room_from_summary( + self, group_id: str, room_id: str, category_id: str + ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -772,8 +774,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="upsert_group_category", ) - def remove_group_category(self, group_id, category_id): - return self.db_pool.simple_delete( + async def remove_group_category(self, group_id: str, category_id: str) -> int: + return await self.db_pool.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -809,8 +811,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="upsert_group_role", ) - def remove_group_role(self, group_id, role_id): - return self.db_pool.simple_delete( + async def remove_group_role(self, group_id: str, role_id: str) -> int: + return await self.db_pool.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", @@ -940,11 +942,13 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - def remove_user_from_summary(self, group_id, user_id, role_id): + async def remove_user_from_summary( + self, group_id: str, user_id: str, role_id: str + ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -1264,16 +1268,16 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_remote_attestion", ) - def remove_attestation_renewal(self, group_id, user_id): + async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: """Remove an attestation that we thought we should renew, but actually shouldn't. Ideally this would never get called as we would never incorrectly try and do attestations for local users on local groups. Args: - group_id (str) - user_id (str) + group_id + user_id """ - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 28f7ae0430..12689f4308 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -529,21 +529,21 @@ class RegistrationWorkerStore(SQLBaseStore): "user_get_threepids", ) - def user_delete_threepid(self, user_id, medium, address): - return self.db_pool.simple_delete( + async def user_delete_threepid(self, user_id, medium, address) -> None: + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", ) - def user_delete_threepids(self, user_id: str): + async def user_delete_threepids(self, user_id: str) -> None: """Delete all threepid this user has bound Args: user_id: The user id to delete all threepids of """ - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -597,21 +597,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="user_get_bound_threepids", ) - def remove_user_bound_threepid(self, user_id, medium, address, id_server): + async def remove_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ) -> None: """The server proxied an unbind request to the given identity server on behalf of the given user, so we remove the mapping of threepid to identity server. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred + user_id + medium + address + id_server """ - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1247,14 +1246,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_user_pending_deactivation", ) - def del_user_pending_deactivation(self, user_id): + async def del_user_pending_deactivation(self, user_id: str) -> None: """ Removes the given user to the table of users who need to be parted from all the rooms they're in, effectively marking that user as fully deactivated. """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 238bad5b45..0e7427e57a 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -123,8 +123,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" + yield defer.ensureDeferred( + self.store.db_pool.simple_delete( + table="event_push_actions", keyvalues={"1": 1}, desc="" + ) ) yield _assert_counts(1, 0) -- cgit 1.5.1 From 5c03134d0f8dd157ea1800ce1a4bcddbdb73ddf1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:54:27 -0400 Subject: Convert additional database code to async/await. (#8195) --- changelog.d/8195.misc | 1 + synapse/appservice/__init__.py | 19 ++- synapse/federation/persistence.py | 19 ++- synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/appservice.py | 15 +- synapse/storage/databases/main/deviceinbox.py | 12 +- synapse/storage/databases/main/e2e_room_keys.py | 30 ++-- synapse/storage/databases/main/event_federation.py | 71 ++++---- synapse/storage/databases/main/group_server.py | 187 +++++++++++++-------- synapse/storage/databases/main/keys.py | 24 +-- synapse/storage/databases/main/transactions.py | 39 +++-- 11 files changed, 246 insertions(+), 175 deletions(-) create mode 100644 changelog.d/8195.misc (limited to 'synapse/storage/databases/main/group_server.py') diff --git a/changelog.d/8195.misc b/changelog.d/8195.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8195.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1ffdc1ed95..69a7182ef4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,11 +14,16 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes +from synapse.appservice.api import ApplicationServiceApi from synapse.types import GroupID, get_domain_from_id from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -35,19 +40,19 @@ class AppServiceTransaction(object): self.id = id self.events = events - def send(self, as_api): + async def send(self, as_api: ApplicationServiceApi) -> bool: """Sends this transaction using the provided AS API interface. Args: - as_api(ApplicationServiceApi): The API to use to send. + as_api: The API to use to send. Returns: - An Awaitable which resolves to True if the transaction was sent. + True if the transaction was sent. """ - return as_api.push_bulk( + return await as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id ) - def complete(self, store): + async def complete(self, store: "DataStore") -> None: """Completes this transaction as successful. Marks this transaction ID on the application service and removes the @@ -55,10 +60,8 @@ class AppServiceTransaction(object): Args: store: The database store to operate on. - Returns: - A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn(service=self.service, txn_id=self.id) + await store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 769cd5de28..de1fe7da38 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module. """ import logging +from typing import Optional, Tuple from synapse.federation.units import Transaction from synapse.logging.utils import log_function @@ -36,25 +37,27 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, origin, transaction): - """ Have we already responded to a transaction with the same id and + async def have_responded( + self, origin: str, transaction: Transaction + ) -> Optional[Tuple[int, JsonDict]]: + """Have we already responded to a transaction with the same id and origin? Returns: - Deferred: Results in `None` if we have not previously responded to - this transaction or a 2-tuple of `(int, dict)` representing the - response code and response body. + `None` if we have not previously responded to this transaction or a + 2-tuple of `(int, dict)` representing the response code and response body. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.get_received_txn_response(transaction.transaction_id, origin) + return await self.store.get_received_txn_response(transaction_id, origin) @log_function async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """ Persist how we responded to a transaction. + """Persist how we responded to a transaction. """ transaction_id = transaction.transaction_id # type: ignore if not transaction_id: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 155d087413..16389a0dca 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1879,8 +1879,8 @@ class FederationHandler(BaseHandler): else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 92f56f1602..454c0bc50c 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore( "application_services_state", {"as_id": service.id}, {"state": state} ) - def create_appservice_txn(self, service, events): + async def create_appservice_txn(self, service, events): """Atomically creates a new transaction for this application service with the given list of events. @@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - def complete_appservice_txn(self, txn_id, service): + async def complete_appservice_txn(self, txn_id, service) -> None: """Completes an application service transaction. Args: txn_id(str): The transaction ID being completed. service(ApplicationService): The application service which was sent this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. """ txn_id = int(txn_id) @@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore( {"txn_id": txn_id, "as_id": service.id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) @@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - def set_appservice_last_pos(self, pos): + async def set_appservice_last_pos(self, pos) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index bb85637a95..0044433110 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + async def delete_device_msgs_for_remote( + self, destination: str, up_to_stream_id: int + ) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. + destination: The destination server_name + up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): @@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 82f9d870fd..12cecceec2 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions - def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi(self, user_id, version, room_keys): """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): that we want to query Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No current backup version") return row[0] - def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. Args: @@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: - A deferred dict giving the info metadata for this backup version, with + A dict giving the info metadata for this backup version, with fields including: version(str) algorithm(str) @@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): result["etag"] = 0 return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @trace - def create_e2e_room_keys_version(self, user_id, info): + async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): info(dict): the info about the backup version to be created Returns: - A deferred string for the newly created version ID + The newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def delete_e2e_room_keys_version(self, user_id, version=None): + async def delete_e2e_room_keys_version( + self, user_id: str, version: Optional[str] = None + ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting + user_id: the user whose backup version we're deleting + version: Optional. the version ID of the backup version we're deleting If missing, we delete the current backup version info. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present, @@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db_pool.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6e5761c7b7..0b69aa6a94 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas include_given: include the given events in result Returns: - list of event_ids + An awaitable which resolve to a list of event_ids """ return await self.db_pool.runInteraction( "get_auth_chain_ids", @@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas chain. Returns: - Deferred[Set[str]] + The set of the difference in auth chains. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db_pool.runInteraction( + async def get_oldest_events_with_depth_in_room(self, room_id): + return await self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def get_prev_events_for_room(self, room_id: str): + async def get_prev_events_for_room(self, room_id: str) -> List[str]: """ Gets a subset of the current forward extremities in the given room. @@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas events which refer to hundreds of prev_events. Args: - room_id (str): room_id + room_id: room_id Returns: - Deferred[List[str]]: the event ids of the forward extremites + The event ids of the forward extremities. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [row[0] for row in txn] - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + async def get_rooms_with_many_extremities( + self, min_count: int, limit: int, room_id_filter: Iterable[str] + ) -> List[str]: """Get the top rooms with at least N extremities. Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results + min_count: The minimum number of extremities + limit: The maximum number of rooms to return. + room_id_filter: room_ids to exclude from the results Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. + At most `limit` room IDs that have at least `min_count` extremities, + sorted by extremity count. """ def _get_rooms_with_many_extremities_txn(txn): @@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas desc="get_latest_event_ids_in_room", ) - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. + async def get_min_depth(self, room_id: str) -> int: + """For the given room, get the minimum depth we have seen for it. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None - def get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas stream_orderings from that point. Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - deferred, which resolves to a list of event_ids + A list of event_ids """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. @@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) - return self._get_forward_extremeties_for_room(room_id, stream_ordering) + return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` Args: - txn - room_id (str) - event_list (list) - limit (int) + room_id + event_list + limit """ event_ids = await self.db_pool.runInteraction( "get_backfill_events", @@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore): _delete_old_forward_extrem_cache_txn, ) - def clean_room_for_join(self, room_id): - return self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id): + return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 6c60171888..ccfbb2135e 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id: str, include_private: bool = False): + async def get_rooms_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Union[str, bool]]]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. @@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: + A list of dictionaries, each in the form of: { "room_id": "!a_room_id:example.com", # The ID of the room @@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore): for room_id, is_public in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_in_group", _get_rooms_in_group_txn ) - def get_rooms_for_summary_by_category( + async def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, - ): + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request Args: @@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: + A tuple containing: * A list of dictionaries with the keys: * "room_id": str, the room ID @@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore): return rooms, categories - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) @@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request - Returns ([users], [roles]) + Returns: + ([users], [roles]) """ def _get_users_for_summary_txn(txn): @@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore): return users, roles - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) @@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group(self, group_id, user_id): """Get a dict describing the membership of a user in a group. Example if joined: @@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore): "is_privileged": False, } - Returns an empty dict if the user is not join/invite/etc + Returns: + An empty dict if the user is not join/invite/etc """ def _get_users_membership_in_group_txn(txn): @@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore): return {} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_publicised_groups_for_user", ) - def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals(self, valid_until_ms): """Get all attestations that need to be renewed until givent time """ @@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore): txn.execute(sql, (valid_until_ms,)) return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) @@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - def get_all_groups_for_user(self, user_id, now_token): + async def get_all_groups_for_user(self, user_id, now_token): def _get_all_groups_for_user_txn(txn): sql = """ SELECT group_id, type, membership, u.content @@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore): desc="set_group_join_policy", ) - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db_pool.runInteraction( + async def add_room_to_summary( + self, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) room's entry in summary. + + Args: + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore): ) def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): + self, + txn, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: """Add (or update) room's entry in summary. Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. + txn + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public """ room_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_group_role", ) - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db_pool.runInteraction( + async def add_user_to_summary( + self, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) user's entry in summary. + + Args: + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore): ) def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public + self, + txn, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], ): """Add (or update) user's entry in summary. Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. + txn + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public """ user_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore): desc="add_group_invite", ) - def add_user_to_group( + async def add_user_to_group( self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): + group_id: str, + user_id: str, + is_admin: bool = False, + is_public: bool = True, + local_attestation: dict = None, + remote_attestation: dict = None, + ) -> None: """Add a user to the group server. Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote + group_id + user_id + is_admin + is_public + local_attestation: The attestation the GS created to give to the remote server. Optional if the user and group are on the same server + remote_attestation: The attestation given to GS by remote server. + Optional if the user and group are on the same server """ def _add_user_to_group_txn(txn): @@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) + await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - def remove_user_from_group(self, group_id, user_id): + async def remove_user_from_group(self, group_id: str, user_id: str) -> None: def _remove_user_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) @@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_room_in_group_visibility", ) - def remove_room_from_group(self, group_id, room_id): + async def remove_room_from_group(self, group_id: str, room_id: str) -> None: def _remove_room_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) @@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def delete_group(self, group_id): + async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. Args: - group_id (str) - - Returns: - Deferred + group_id: The group ID to delete. """ def _delete_group_txn(txn): @@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore): txn, table=table, keyvalues={"group_id": group_id} ) - return self.db_pool.runInteraction("delete_group", _delete_group_txn) + await self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 1c0a049c55..ad43bb05ab 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Iterable, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes @@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore): @cachedList( cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" ) - def get_server_verify_keys(self, server_name_and_key_ids): + async def get_server_verify_keys( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: """ Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): + server_name_and_key_ids: iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown """ keys = {} @@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.db_pool.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_verify_keys", _txn) async def store_server_verify_keys( self, @@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore): desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_server_keys_json", _get_server_keys_json_txn ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 2efcc0dc66..5b31aab700 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Optional, Tuple from canonicaljson import encode_canonical_json @@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore): expiry_ms=5 * 60 * 1000, ) - def get_received_txn_response(self, transaction_id, origin): + async def get_received_txn_response( + self, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). Args: - transaction_id (str) - origin(str) + transaction_id + origin Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) + None if we have not previously responded to this transaction or a + 2-tuple of (int, dict) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore): else: return None - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): + async def set_destination_retry_timings( + self, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms + destination + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms """ self._destination_retry_cache.pop(destination, None) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore): "cleanup_transactions", self._cleanup_transactions ) - def _cleanup_transactions(self): + async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) -- cgit 1.5.1