summary refs log tree commit diff
path: root/synapse/storage/databases/main/group_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/group_server.py')
-rw-r--r--synapse/storage/databases/main/group_server.py156
1 files changed, 94 insertions, 62 deletions
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py

index 3f6086050b..0aef121d83 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.types import JsonDict from synapse.util import json_encoder @@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore): ) -> List[Dict[str, Any]]: # TODO: Pagination - keyvalues = {"group_id": group_id} + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore): # TODO: Pagination - def _get_rooms_in_group_txn(txn): + def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: sql = """ SELECT room_id, is_public FROM group_rooms WHERE group_id = ? @@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore): * "order": int, the sort order of rooms in this category """ - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_rooms_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - async def get_group_categories(self, group_id): + async def get_group_categories(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, @@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_category(self, group_id, category_id): + async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore): return category - async def get_group_roles(self, group_id): + async def get_group_roles(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, @@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_role(self, group_id, role_id): + async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - async def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role( + self, group_id: str, include_private: bool = False + ) -> Tuple[List[JsonDict], JsonDict]: """Get the users and roles that should be included in a summary request Returns: ([users], [roles]) """ - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_users_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], JsonDict]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - async def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group( + self, group_id: str, user_id: str + ) -> JsonDict: """Get a dict describing the membership of a user in a group. Example if joined: @@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore): An empty dict if the user is not join/invite/etc """ - def _get_users_membership_in_group_txn(txn): + def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: row = self.db_pool.simple_select_one_txn( txn, table="group_users", @@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_publicised_groups_for_user", ) - async def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals( + self, valid_until_ms: int + ) -> List[Dict[str, Any]]: """Get all attestations that need to be renewed until givent time""" - def _get_attestations_need_renewals_txn(txn): + def _get_attestations_need_renewals_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT group_id, user_id FROM group_attestations_renewals WHERE valid_until_ms <= ? @@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - async def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation( + self, group_id: str, user_id: str + ) -> Optional[JsonDict]: """Get the attestation that proves the remote agrees that the user is in the group. """ @@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): + async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content FROM local_group_updates AS u @@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - async def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( + async def get_groups_changes_for_user( + self, user_id: str, from_token: int, to_token: int + ) -> List[JsonDict]: + has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] user_id, from_token ) if not has_changed: return [] - def _get_groups_changes_for_user_txn(txn): + def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, membership, type, u.content FROM local_group_updates AS u @@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore): """ last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] if not has_changed: return [], current_id, False - def _get_all_groups_changes_txn(txn): + def _get_all_groups_changes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, group_id, user_id, type, content FROM local_group_updates @@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ], + ) limited = False upto_token = current_id @@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore): def _add_room_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore): "category_id": category_id, "room_id": room_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: str + self, group_id: str, room_id: str, category_id: Optional[str] ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID @@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/update room category for group""" - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"category_id": category_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/remove user role""" - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"role_id": role_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) user's entry in summary. @@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore): def _add_user_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], - ): + ) -> None: """Add (or update) user's entry in summary. Args: @@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore): "role_id": role_id, "user_id": user_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: str + self, group_id: str, user_id: str, role_id: Optional[str] ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID @@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore): Optional if the user and group are on the same server """ - def _add_user_to_group_txn(txn): + def _add_user_to_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, table="group_users", @@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore): await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn): + def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_users", @@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn): + def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_rooms", @@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore): content = content or {} - def _register_user_group_membership_txn(txn, next_id): + def _register_user_group_membership_txn( + txn: LoggingTransaction, next_id: int + ) -> int: # TODO: Upsert? self.db_pool.simple_delete_txn( txn, @@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore): ), }, ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] # TODO: Insert profile to ensure it comes down stream if its a join. @@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - async with self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, @@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore): return res async def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description + self, + group_id: str, + user_id: str, + name: str, + avatar_url: str, + short_description: str, + long_description: str, ) -> None: await self.db_pool.simple_insert( table="groups", @@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - async def update_group_profile(self, group_id, profile): + async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, @@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_attestation_renewal", ) - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() + def get_group_stream_token(self) -> int: + return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. @@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore): group_id: The group ID to delete. """ - def _delete_group_txn(txn): + def _delete_group_txn(txn: LoggingTransaction) -> None: tables = [ "groups", "group_users",