diff options
author | Erik Johnston <erik@matrix.org> | 2022-03-29 10:15:25 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-03-29 10:15:25 +0100 |
commit | fd1b6334f071c7bbd88928513e99e5c8971ac414 (patch) | |
tree | 7333cc7290baba787047b22d0e5bdb875cd8e59b /synapse/storage/databases | |
parent | Merge branch 'release-v1.55' of github.com:matrix-org/synapse into matrix-org... (diff) | |
parent | Remove unused `auth_event_ids` argument plumbing (#12304) (diff) | |
download | synapse-fd1b6334f071c7bbd88928513e99e5c8971ac414.tar.xz |
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/account_data.py | 41 | ||||
-rw-r--r-- | synapse/storage/databases/main/cache.py | 61 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 12 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 7 | ||||
-rw-r--r-- | synapse/storage/databases/main/group_server.py | 156 | ||||
-rw-r--r-- | synapse/storage/databases/main/media_repository.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/monthly_active_users.py | 38 | ||||
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 37 | ||||
-rw-r--r-- | synapse/storage/databases/main/registration.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 149 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 37 | ||||
-rw-r--r-- | synapse/storage/databases/main/search.py | 26 | ||||
-rw-r--r-- | synapse/storage/databases/main/state.py | 24 | ||||
-rw-r--r-- | synapse/storage/databases/main/stats.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/user_directory.py | 31 |
16 files changed, 330 insertions, 299 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 52146aacc8..9af9f4f18e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, + cast, +) from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) @cached(max_entries=5000, iterable=True) - async def ignored_by(self, user_id: str) -> Set[str]: + async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ Get users which ignore the given user. @@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) Return: The user IDs which ignore the given user. """ - return set( + return frozenset( await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, @@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) ) + @cached(max_entries=5000, iterable=True) + async def ignored_users(self, user_id: str) -> FrozenSet[str]: + """ + Get users which the given user ignores. + + Params: + user_id: The user ID which is making the request. + + Return: + The user IDs which are ignored by the given user. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + desc="ignored_users", + ) + ) + def process_replication_rows( self, stream_name: str, @@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: currently_ignored_users = set() + # If the data has not changed, nothing to do. + if previously_ignored_users == currently_ignored_users: + return + # Delete entries which are no longer ignored. self.db_pool.simple_delete_many_txn( txn, @@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) async def purge_account_data_for_user(self, user_id: str) -> None: """ diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index d6a2df1afe..dd4e83a2ad 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, + EventsStreamRow, ) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -31,6 +32,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_caches_txn(txn): + def get_all_updated_caches_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. @@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) @@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token, row): + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: + assert isinstance(data, EventsStreamEventRow) self._invalidate_caches_for_event( token, data.event_id, @@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( @@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_caches_for_event( self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): + stream_ordering: int, + event_id: str, + room_id: str, + etype: str, + state_key: Optional[str], + redacts: Optional[str], + relates_to: Optional[str], + backfilled: bool, + ) -> None: self._invalidate_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) @@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + self._get_membership_from_event_id.invalidate((event_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) @@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_thread_summary.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,)) - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + async def invalidate_cache_and_stream( + self, cache_name: str, keys: Tuple[Any, ...] + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore): keys, ) - def _invalidate_cache_and_stream(self, txn, cache_func, keys): + def _invalidate_cache_and_stream( + self, + txn: LoggingTransaction, + cache_func: _CachedFunction, + keys: Tuple[Any, ...], + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - def _invalidate_all_cache_and_stream(self, txn, cache_func): + def _invalidate_all_cache_and_stream( + self, txn: LoggingTransaction, cache_func: _CachedFunction + ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. """ @@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): + self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + ) -> None: """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. @@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self.clock.time_msec(), + "invalidation_ts": self._clock.time_msec(), }, ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 277e6422eb..634e19e035 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas /* Get the depth and stream_ordering of the prev_event_id from the events table */ INNER JOIN events ON prev_event_id = events.event_id + + /* exclude outliers from the results (we don't have the state, so cannot + * verify if the requesting server can see them). + */ + WHERE NOT events.outlier + /* Look for an edge which matches the given event_id */ - WHERE event_edges.event_id = ? - AND event_edges.is_state = ? + AND event_edges.event_id = ? AND NOT event_edges.is_state + /* Because we can have many events at the same depth, * we want to also tie-break and sort on stream_ordering */ ORDER BY depth DESC, stream_ordering DESC @@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute( connected_prev_event_query, - (event_id, False, limit), + (event_id, limit), ) return [ BackfillQueueNavigationItem( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f60aef180..d253243125 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1745,6 +1745,13 @@ class PersistEventsStore: (event.state_key,), ) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + txn.call_after( + self.store._get_membership_from_event_id.invalidate, + (event.event_id,), + ) + # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. # diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 3f6086050b..0aef121d83 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.types import JsonDict from synapse.util import json_encoder @@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore): ) -> List[Dict[str, Any]]: # TODO: Pagination - keyvalues = {"group_id": group_id} + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore): # TODO: Pagination - def _get_rooms_in_group_txn(txn): + def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]: sql = """ SELECT room_id, is_public FROM group_rooms WHERE group_id = ? @@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore): * "order": int, the sort order of rooms in this category """ - def _get_rooms_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_rooms_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - async def get_group_categories(self, group_id): + async def get_group_categories(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, @@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_category(self, group_id, category_id): + async def get_group_category(self, group_id: str, category_id: str) -> JsonDict: category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore): return category - async def get_group_roles(self, group_id): + async def get_group_roles(self, group_id: str) -> JsonDict: rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, @@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - async def get_group_role(self, group_id, role_id): + async def get_group_role(self, group_id: str, role_id: str) -> JsonDict: role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - async def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role( + self, group_id: str, include_private: bool = False + ) -> Tuple[List[JsonDict], JsonDict]: """Get the users and roles that should be included in a summary request Returns: ([users], [roles]) """ - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} + def _get_users_for_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], JsonDict]: + keyvalues: JsonDict = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True @@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - async def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group( + self, group_id: str, user_id: str + ) -> JsonDict: """Get a dict describing the membership of a user in a group. Example if joined: @@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore): An empty dict if the user is not join/invite/etc """ - def _get_users_membership_in_group_txn(txn): + def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict: row = self.db_pool.simple_select_one_txn( txn, table="group_users", @@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_publicised_groups_for_user", ) - async def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals( + self, valid_until_ms: int + ) -> List[Dict[str, Any]]: """Get all attestations that need to be renewed until givent time""" - def _get_attestations_need_renewals_txn(txn): + def _get_attestations_need_renewals_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, Any]]: sql = """ SELECT group_id, user_id FROM group_attestations_renewals WHERE valid_until_ms <= ? @@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - async def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation( + self, group_id: str, user_id: str + ) -> Optional[JsonDict]: """Get the attestation that proves the remote agrees that the user is in the group. """ @@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - async def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): + async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: + def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, type, membership, u.content FROM local_group_updates AS u @@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - async def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( + async def get_groups_changes_for_user( + self, user_id: str, from_token: int, to_token: int + ) -> List[JsonDict]: + has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined] user_id, from_token ) if not has_changed: return [] - def _get_groups_changes_for_user_txn(txn): + def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: sql = """ SELECT group_id, membership, type, u.content FROM local_group_updates AS u @@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore): """ last_id = int(last_id) - has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined] if not has_changed: return [], current_id, False - def _get_all_groups_changes_txn(txn): + def _get_all_groups_changes_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: sql = """ SELECT stream_id, group_id, user_id, type, content FROM local_group_updates @@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [ - (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) - for stream_id, group_id, user_id, gtype, content_json in txn - ] + updates = cast( + List[Tuple[int, tuple]], + [ + (stream_id, (group_id, user_id, gtype, db_to_json(content_json))) + for stream_id, group_id, user_id, gtype, content_json in txn + ], + ) limited = False upto_token = current_id @@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore): def _add_room_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, room_id: str, - category_id: str, - order: int, + category_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) room's entry in summary. @@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore): "category_id": category_id, "room_id": room_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_summary( - self, group_id: str, room_id: str, category_id: str + self, group_id: str, room_id: str, category_id: Optional[str] ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID @@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/update room category for group""" - insertion_values = {} - update_values = {"category_id": category_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"category_id": category_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore): is_public: Optional[bool], ) -> None: """Add/remove user role""" - insertion_values = {} - update_values = {"role_id": role_id} # This cannot be empty + insertion_values: JsonDict = {} + update_values: JsonDict = {"role_id": role_id} # This cannot be empty if profile is None: insertion_values["profile"] = "{}" @@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore): self, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], ) -> None: """Add (or update) user's entry in summary. @@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore): def _add_user_to_summary_txn( self, - txn, + txn: LoggingTransaction, group_id: str, user_id: str, - role_id: str, - order: int, + role_id: Optional[str], + order: Optional[int], is_public: Optional[bool], - ): + ) -> None: """Add (or update) user's entry in summary. Args: @@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - (order,) = txn.fetchone() + (order,) = cast(Tuple[int], txn.fetchone()) if existing: to_update = {} @@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore): "role_id": role_id, "user_id": user_id, }, - values=to_update, + updatevalues=to_update, ) else: if is_public is None: @@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_user_from_summary( - self, group_id: str, user_id: str, role_id: str + self, group_id: str, user_id: str, role_id: Optional[str] ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID @@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore): Optional if the user and group are on the same server """ - def _add_user_to_group_txn(txn): + def _add_user_to_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_insert_txn( txn, table="group_users", @@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore): await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) async def remove_user_from_group(self, group_id: str, user_id: str) -> None: - def _remove_user_from_group_txn(txn): + def _remove_user_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_users", @@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def remove_room_from_group(self, group_id: str, room_id: str) -> None: - def _remove_room_from_group_txn(txn): + def _remove_room_from_group_txn(txn: LoggingTransaction) -> None: self.db_pool.simple_delete_txn( txn, table="group_rooms", @@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore): content = content or {} - def _register_user_group_membership_txn(txn, next_id): + def _register_user_group_membership_txn( + txn: LoggingTransaction, next_id: int + ) -> int: # TODO: Upsert? self.db_pool.simple_delete_txn( txn, @@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore): ), }, ) - self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined] # TODO: Insert profile to ensure it comes down stream if its a join. @@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - async with self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined] res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, @@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore): return res async def create_group( - self, group_id, user_id, name, avatar_url, short_description, long_description + self, + group_id: str, + user_id: str, + name: str, + avatar_url: str, + short_description: str, + long_description: str, ) -> None: await self.db_pool.simple_insert( table="groups", @@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - async def update_group_profile(self, group_id, profile): + async def update_group_profile(self, group_id: str, profile: JsonDict) -> None: await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, @@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_attestation_renewal", ) - def get_group_stream_token(self): - return self._group_updates_id_gen.get_current_token() + def get_group_stream_token(self) -> int: + return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined] async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. @@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore): group_id: The group ID to delete. """ - def _delete_group_txn(txn): + def _delete_group_txn(txn: LoggingTransaction) -> None: tables = [ "groups", "group_users", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index cbba356b4a..322ed05390 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e9a0cdc6be..216622964a 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, + LoggingTransaction, make_in_list_sql_clause, ) +from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.util.caches.descriptors import cached from synapse.util.threepids import canonicalise_email @@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): Number of current monthly active users """ - def _count_users(txn): + def _count_users(txn: LoggingTransaction) -> int: # Exclude app service users sql = """ SELECT COUNT(*) @@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); """ txn.execute(sql) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_users", _count_users) @@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ - def _count_users_by_service(txn): + def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]: sql = """ SELECT COALESCE(appservice_id, 'native'), COUNT(*) FROM monthly_active_users @@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ txn.execute(sql) - result = txn.fetchall() + result = cast(List[Tuple[str, int]], txn.fetchall()) return dict(result) return await self.db_pool.runInteraction( @@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): ) @wrap_as_background_process("reap_monthly_active_users") - async def reap_monthly_active_users(self): + async def reap_monthly_active_users(self) -> None: """Cleans out monthly active user table to ensure that no stale entries exist. """ - def _reap_users(txn, reserved_users): + def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: """ Args: reserved_users (tuple): reserved users to preserve @@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): # is racy. # Have resolved to invalidate the whole cache for now and do # something about it if and when the perf becomes significant - self._invalidate_all_cache_and_stream( + self._invalidate_all_cache_and_stream( # type: ignore[attr-defined] txn, self.user_last_seen_monthly_active ) - self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined] reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( @@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): ) -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore): def __init__( self, database: DatabasePool, @@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) - def _initialise_reserved_users(self, txn, threepids): + def _initialise_reserved_users( + self, txn: LoggingTransaction, threepids: List[dict] + ) -> None: """Ensures that reserved threepids are accounted for in the MAU table, should be called on start up. Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve + txn: + threepids: List of threepid dicts to reserve """ # XXX what is this function trying to achieve? It upserts into @@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - def upsert_monthly_active_user_txn(self, txn, user_id): + def upsert_monthly_active_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> None: """Updates or inserts monthly active user member We consciously do not call is_support_txn from this method because it @@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn, self.user_last_seen_monthly_active, (user_id,) ) - async def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id: str) -> None: """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = await self.is_guest(user_id) + is_guest = await self.is_guest(user_id) # type: ignore[attr-defined] if is_guest: return is_trial = await self.is_trial_user(user_id) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index bf0b903af2..e6f97aeece 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -24,10 +24,9 @@ from typing import ( Optional, Set, Tuple, + cast, ) -from twisted.internet import defer - from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream @@ -38,7 +37,11 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( @@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore): " AND user_id = ?" ) txn.execute(sql, (user_id,)) - return txn.fetchall() + return cast(List[Tuple[str, str, int, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f @@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if not rows: return [] - content = {} + content: JsonDict = {} for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] @@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "_get_linearized_receipts_for_rooms", f ) - results = {} + results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. @@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "get_linearized_receipts_for_all_rooms", f ) - results = {} + results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. @@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """ if last_id == current_id: - return defer.succeed([]) + return [] def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ @@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] + updates = cast( + List[Tuple[int, list]], + [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + ) limited = False upper_bound = current_id @@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(instance_name, token) for row in rows: @@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) if receipt_type == ReceiptTypes.READ and stream_ordering is not None: - self._remove_old_push_actions_before_txn( + self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - async with self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index a698d10cc5..7f3d190e94 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -22,6 +22,7 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( DatabasePool, @@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ): super().__init__(database, db_conn, hs) - self.config = hs.config + self.config: HomeServerConfig = hs.config # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c4869d64e6..b2295fd51f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. - """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated(summaries.keys(), user_id) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 94068940b9..18b1acd9e1 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -34,6 +34,7 @@ import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ): super().__init__(database, db_conn, hs) - self.config = hs.config + self.config: HomeServerConfig = hs.config async def store_room( self, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bef675b845..3248da5356 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +@attr.s(frozen=True, slots=True, auto_attribs=True) +class EventIdMembership: + """Returned by `get_membership_from_event_ids`""" + + user_id: str + membership: str + + class RoomMemberWorkerStore(EventsWorkerStore): def __init__( self, @@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): retcols=("user_id", "display_name", "avatar_url", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=500, - desc="_get_membership_from_event_ids", + desc="_get_joined_profiles_from_event_ids", ) return { @@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + @cached(max_entries=5000) + async def _get_membership_from_event_id( + self, member_event_id: str + ) -> Optional[EventIdMembership]: + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_membership_from_event_id", list_name="member_event_ids" + ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> List[dict]: - """Get user_id and membership of a set of event IDs.""" + ) -> Dict[str, Optional[EventIdMembership]]: + """Get user_id and membership of a set of event IDs. + + Returns: + Mapping from event ID to `EventIdMembership` if the event is a + membership event, otherwise the value is None. + """ - return await self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_membership_from_event_ids", ) + return { + row["event_id"]: EventIdMembership( + membership=row["membership"], user_id=row["user_id"] + ) + for row in rows + } + async def is_local_host_in_room_ignoring_users( self, room_id: str, ignore_users: Collection[str] ) -> bool: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index bb41beb827..79abe758e6 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set import attr @@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore): " VALUES (?,?,?,to_tsvector('english', ?),?,?)" ) - args = ( + args1 = ( ( entry.event_id, entry.room_id, @@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.execute_batch(sql, args) + txn.execute_batch(sql, args1) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( "INSERT INTO event_search (event_id, room_id, key, value)" " VALUES (?,?,?,?)" ) - args = ( + args2 = ( ( entry.event_id, entry.room_id, @@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore): ) for entry in entries ) - txn.execute_batch(sql, args) + txn.execute_batch(sql, args2) else: # This should be unreachable. @@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_query = _parse_query(self.database_engine, search_term) - args = [] + args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = await self.get_events_as_list( + events = await self.get_events_as_list( # type: ignore[attr-defined] [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore): room_ids: Collection[str], search_term: str, keys: Iterable[str], - limit, + limit: int, pagination_token: Optional[str] = None, ) -> JsonDict: """Performs a full text search over events with given keys. @@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore): search_query = _parse_query(self.database_engine, search_term) - args = [] + args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. @@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore): if pagination_token: try: - origin_server_ts, stream = pagination_token.split(",") - origin_server_ts = int(origin_server_ts) - stream = int(stream) + origin_server_ts_str, stream_str = pagination_token.split(",") + origin_server_ts = int(origin_server_ts_str) + stream = int(stream_str) except Exception: raise SynapseError(400, "Invalid pagination token") @@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = await self.get_events_as_list( + events = await self.get_events_as_list( # type: ignore[attr-defined] [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 417aef1dbc..28460fd364 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,7 @@ # limitations under the License. import collections.abc import logging -from typing import TYPE_CHECKING, Iterable, Optional, Set +from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -29,7 +29,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.state import StateFilter -from synapse.types import StateMap +from synapse.types import JsonDict, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # We delegate to the cached version return await self.get_current_state_ids(room_id) - def _get_filtered_current_state_ids_txn(txn): + def _get_filtered_current_state_ids_txn( + txn: LoggingTransaction, + ) -> StateMap[str]: results = {} sql = """ SELECT type, state_key, event_id FROM current_state_events @@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): event_id = state.get((EventTypes.CanonicalAlias, "")) if not event_id: - return + return None event = await self.get_event(event_id, allow_none=True) if not event: - return + return None return event.content.get("canonical_alias") @@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): list_name="event_ids", num_args=1, ) - async def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict: """Returns mapping event_id -> state_group""" rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", @@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): ): super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname self.db_pool.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, @@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): self._background_remove_left_rooms, ) - async def _background_remove_left_rooms(self, progress, batch_size): + async def _background_remove_left_rooms( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to delete rows from `current_state_events` and `event_forward_extremities` tables of rooms that the server is no longer joined to. @@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): last_room_id = progress.get("last_room_id", "") - def _background_remove_left_rooms_txn(txn): + def _background_remove_left_rooms_txn( + txn: LoggingTransaction, + ) -> Tuple[bool, Set[str]]: # get a batch of room ids to consider sql = """ SELECT DISTINCT room_id FROM current_state_events diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 427ae1f649..b95dbef678 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore): ): super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname self.clock = self.hs.get_clock() self.stats_enabled = hs.config.stats.stats_enabled diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index e7fddd2426..df772d4721 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -26,6 +26,8 @@ from typing import ( cast, ) +from typing_extensions import TypedDict + from synapse.api.errors import StoreError if TYPE_CHECKING: @@ -40,7 +42,12 @@ from synapse.storage.database import ( from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id +from synapse.types import ( + JsonDict, + UserProfile, + get_domain_from_id, + get_localpart_from_id, +) from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) -> None: super().__init__(database, db_conn, hs) - self.server_name = hs.hostname + self.server_name: str = hs.hostname self.db_pool.updates.register_background_update_handler( "populate_user_directory_createtables", @@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) +class SearchResult(TypedDict): + limited: bool + results: List[UserProfile] + + class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_shared_rooms_for_users( + async def get_mutual_rooms_for_users( self, user_id: str, other_user_id: str ) -> Set[str]: """ @@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): A set of room ID's that the users share. """ - def _get_shared_rooms_for_users_txn( + def _get_mutual_rooms_for_users_txn( txn: LoggingTransaction, ) -> List[Dict[str, str]]: txn.execute( @@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return rows rows = await self.db_pool.runInteraction( - "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn + "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn ) return {row["room_id"] for row in rows} @@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): async def search_user_dir( self, user_id: str, search_term: str, limit: int - ) -> JsonDict: + ) -> SearchResult: """Searches for users in directory Returns: @@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = await self.db_pool.execute( - "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + results = cast( + List[UserProfile], + await self.db_pool.execute( + "search_user_dir", self.db_pool.cursor_to_dict, sql, *args + ), ) limited = len(results) > limit |