diff options
author | Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> | 2022-05-17 16:29:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-17 15:29:06 +0100 |
commit | 6edefef60289cc54e17fd6af838eb66c4973f5f5 (patch) | |
tree | fd254e6d378e0f877c39826b89e8db47785abcdb /synapse/storage | |
parent | Add a new room version for MSC3787's knock+restricted join rule (#12623) (diff) | |
download | synapse-6edefef60289cc54e17fd6af838eb66c4973f5f5.tar.xz |
Add some type hints to datastore (#12717)
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/__init__.py | 8 | ||||
-rw-r--r-- | synapse/storage/databases/main/metrics.py | 56 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 184 | ||||
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 126 |
4 files changed, 229 insertions, 145 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5895b89202..d545a1c002 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -26,11 +26,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - IdGenerator, - MultiWriterIdGenerator, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -155,8 +151,6 @@ class DataStore( ], ) - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index d03555a585..14294a0bb8 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -14,16 +14,19 @@ import calendar import logging import time -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, List, Tuple, cast from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.types import Cursor if TYPE_CHECKING: from synapse.server import HomeServer @@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): @wrap_as_background_process("read_forward_extremities") async def _read_forward_extremities(self) -> None: - def fetch(txn): + def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]: txn.execute( """ SELECT t1.c, t2.c @@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) t2 ON t1.room_id = t2.room_id """ ) - return txn.fetchall() + return cast(List[Tuple[int, int]], txn.fetchall()) res = await self.db_pool.runInteraction("read_forward_extremities", fetch) @@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): call to this function, it will return None. """ - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) async def count_daily_sent_e2ee_messages(self) -> int: - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: # This is good enough as if you have silly characters in your own # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname @@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( @@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) async def count_daily_active_e2ee_rooms(self) -> int: - def _count(txn): + def _count(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.encrypted' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( @@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): call to this function, it will return None. """ - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(*) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_messages", _count_messages) async def count_daily_sent_messages(self) -> int: - def _count_messages(txn): + def _count_messages(txn: LoggingTransaction) -> int: # This is good enough as if you have silly characters in your own # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname @@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction( @@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) async def count_daily_active_rooms(self) -> int: - def _count(txn): + def _count(txn: LoggingTransaction) -> int: sql = """ SELECT COUNT(DISTINCT room_id) FROM events WHERE type = 'm.room.message' AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) @@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_monthly_users", self._count_users, thirty_days_ago ) - def _count_users(self, txn: Cursor, time_from: int) -> int: + def _count_users(self, txn: LoggingTransaction, time_from: int) -> int: """ Returns number of users seen in the past time_from period """ @@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): # Mypy knows that fetchone() might return None if there are no rows. # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always # returns exactly one row. - (count,) = txn.fetchone() # type: ignore[misc] + (count,) = cast(Tuple[int], txn.fetchone()) return count async def count_r30_users(self) -> Dict[str, int]: @@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): A mapping of counts globally as well as broken out by platform. """ - def _count_r30_users(txn): + def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) thirty_days_ago_in_secs = now - thirty_days_in_secs @@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) results["all"] = count return results @@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): - "web" (any web application -- it's not possible to distinguish Element Web here) """ - def _count_r30v2_users(txn): + def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs @@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): thirty_days_in_secs * 1000, ), ) - row = txn.fetchone() - if row is None: - results["all"] = 0 - else: - results["all"] = row[0] + (count,) = cast(Tuple[int], txn.fetchone()) + results["all"] = count return results @@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): Generates daily visit data for use in cohort/ retention analysis """ - def _generate_user_daily_visits(txn): + def _generate_user_daily_visits(txn: LoggingTransaction) -> None: logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() a_day_in_milliseconds = 24 * 60 * 60 * 1000 diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 4ed913e248..0e2855fb44 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,14 +14,18 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore @@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, AbstractStreamIdTracker, + IdGenerator, StreamIdGenerator, ) +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -57,7 +64,11 @@ def _is_experimental_rule_enabled( return True -def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig): +def _load_rules( + rawrules: List[JsonDict], + enabled_map: Dict[str, bool], + experimental_config: ExperimentalConfig, +) -> List[JsonDict]: ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -137,7 +148,7 @@ class PushRulesWorkerStore( ) @abc.abstractmethod - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: @@ -146,7 +157,7 @@ class PushRulesWorkerStore( raise NotImplementedError() @cached(max_entries=5000) - async def get_push_rules_for_user(self, user_id): + async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]: rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, @@ -168,7 +179,7 @@ class PushRulesWorkerStore( return _load_rules(rows, enabled_map, self.hs.config.experimental) @cached(max_entries=5000) - async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]: + async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, @@ -184,13 +195,13 @@ class PushRulesWorkerStore( return False else: - def have_push_rules_changed_txn(txn): + def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool: sql = ( "SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? < stream_id" ) txn.execute(sql, (user_id, last_id)) - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) return bool(count) return await self.db_pool.runInteraction( @@ -202,11 +213,13 @@ class PushRulesWorkerStore( list_name="user_ids", num_args=1, ) - async def bulk_get_push_rules(self, user_ids): + async def bulk_get_push_rules( + self, user_ids: Collection[str] + ) -> Dict[str, List[JsonDict]]: if not user_ids: return {} - results = {user_id: [] for user_id in user_ids} + results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", @@ -250,7 +263,7 @@ class PushRulesWorkerStore( condition["pattern"] = new_room_id # Add the rule for the new room - await self.add_push_rule( + await self.add_push_rule( # type: ignore[attr-defined] user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], @@ -286,11 +299,13 @@ class PushRulesWorkerStore( list_name="user_ids", num_args=1, ) - async def bulk_get_push_rules_enabled(self, user_ids): + async def bulk_get_push_rules_enabled( + self, user_ids: Collection[str] + ) -> Dict[str, Dict[str, bool]]: if not user_ids: return {} - results = {user_id: {} for user_id in user_ids} + results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", @@ -306,7 +321,7 @@ class PushRulesWorkerStore( async def get_all_push_rule_updates( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: """Get updates for push_rules replication stream. Args: @@ -331,7 +346,9 @@ class PushRulesWorkerStore( if last_id == current_id: return [], current_id, False - def get_all_push_rule_updates_txn(txn): + def get_all_push_rule_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: sql = """ SELECT stream_id, user_id FROM push_rules_stream @@ -340,7 +357,10 @@ class PushRulesWorkerStore( LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] + updates = cast( + List[Tuple[int, Tuple[str]]], + [(stream_id, (user_id,)) for stream_id, user_id in txn], + ) limited = False upper_bound = current_id @@ -356,15 +376,30 @@ class PushRulesWorkerStore( class PushRuleStore(PushRulesWorkerStore): + # Because we have write access, this will be a StreamIdGenerator + # (see PushRulesWorkerStore.__init__) + _push_rules_stream_id_gen: AbstractStreamIdGenerator + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + async def add_push_rule( self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, + user_id: str, + rule_id: str, + priority_class: int, + conditions: List[Dict[str, str]], + actions: List[Union[JsonDict, str]], + before: Optional[str] = None, + after: Optional[str] = None, ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) @@ -400,17 +435,17 @@ class PushRuleStore(PushRulesWorkerStore): def _add_push_rule_relative_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + conditions_json: str, + actions_json: str, + before: str, + after: str, + ) -> None: # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. self.database_engine.lock_table(txn, "push_rules") @@ -470,15 +505,15 @@ class PushRuleStore(PushRulesWorkerStore): def _add_push_rule_highest_priority_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + conditions_json: str, + actions_json: str, + ) -> None: # Lock the table since otherwise we'll have annoying races between the # SELECT here and the UPSERT below. self.database_engine.lock_table(txn, "push_rules") @@ -510,17 +545,17 @@ class PushRuleStore(PushRulesWorkerStore): def _upsert_push_rule_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + priority_class: int, + priority: int, + conditions_json: str, + actions_json: str, + update_stream: bool = True, + ) -> None: """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -600,7 +635,11 @@ class PushRuleStore(PushRulesWorkerStore): rule_id: The rule_id of the rule to be deleted """ - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + def delete_push_rule_txn( + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + ) -> None: # we don't use simple_delete_one_txn because that would fail if the # user did not have a push_rule_enable row. self.db_pool.simple_delete_txn( @@ -661,14 +700,14 @@ class PushRuleStore(PushRulesWorkerStore): def _set_push_rule_enabled_txn( self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - is_default_rule, - ): + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + enabled: bool, + is_default_rule: bool, + ) -> None: new_id = self._push_rules_enable_id_gen.get_next() if not is_default_rule: @@ -740,7 +779,11 @@ class PushRuleStore(PushRulesWorkerStore): """ actions_json = json_encoder.encode(actions) - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + def set_push_rule_actions_txn( + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + ) -> None: if is_default_rule: # Add a dummy rule to the rules table with the user specified # actions. @@ -794,8 +837,15 @@ class PushRuleStore(PushRulesWorkerStore): ) def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): + self, + txn: LoggingTransaction, + stream_id: int, + event_stream_ordering: int, + user_id: str, + rule_id: str, + op: str, + data: Optional[JsonDict] = None, + ) -> None: values = { "stream_id": stream_id, "event_stream_ordering": event_stream_ordering, @@ -814,5 +864,5 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: return self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 48e83592e7..608d40dfa1 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -46,7 +51,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, get_domain_from_id +from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @wrap_as_background_process("_count_known_servers") - async def _count_known_servers(self): + async def _count_known_servers(self) -> int: """ Count the servers that this server knows about. @@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): `synapse_federation_known_servers` LaterGauge to collect. """ - def _transact(txn): + def _transact(txn: LoggingTransaction) -> int: if isinstance(self.database_engine, Sqlite3Engine): query = """ SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) @@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._known_servers_count = max([count, 1]) return self._known_servers_count - def _check_safe_current_state_events_membership_updated_txn(self, txn): + def _check_safe_current_state_events_membership_updated_txn( + self, txn: LoggingTransaction + ) -> None: """Checks if it is safe to assume the new current_state_events membership column is up to date """ @@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: + def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): A mapping from user ID to ProfileInfo. """ - def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]: + def _get_users_in_room_with_profiles( + txn: LoggingTransaction, + ) -> Dict[str, ProfileInfo]: sql = """ SELECT state_key, display_name, avatar_url FROM room_memberships as m INNER JOIN current_state_events as c @@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): dict of membership states, pointing to a MemberSummary named tuple. """ - def _get_room_summary_txn(txn): + def _get_room_summary_txn( + txn: LoggingTransaction, + ) -> Dict[str, MemberSummary]: # first get counts. # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats @@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (room_id,)) - res = {} + res: Dict[str, MemberSummary] = {} for count, membership in txn: res.setdefault(membership, MemberSummary([], count)) @@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): def _get_rooms_for_local_user_where_membership_is_txn( self, - txn, + txn: LoggingTransaction, user_id: str, membership_list: List[str], ) -> List[RoomsForUser]: @@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) def _get_rooms_for_user_with_stream_ordering_txn( - self, txn, user_id: str + self, txn: LoggingTransaction, user_id: str ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called @@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) def _get_rooms_for_users_with_stream_ordering_txn( - self, txn, user_ids: Collection[str] + self, txn: LoggingTransaction, user_ids: Collection[str] ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: clause, args = make_in_list_sql_clause( @@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, [Membership.JOIN] + args) - result = {user_id: set() for user_id in user_ids} + result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { + user_id: set() for user_id in user_ids + } for user_id, room_id, instance, stream_id in txn: result[user_id].add( GetRoomsForUserWithStreamOrdering( @@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): if not user_ids: return set() - def _get_users_server_still_shares_room_with_txn(txn): + def _get_users_server_still_shares_room_with_txn( + txn: LoggingTransaction, + ) -> Set[str]: sql = """ SELECT state_key FROM current_state_events WHERE @@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: - state_group = context.state_group + state_group: Union[object, int] = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() current_state_ids = await context.get_current_state_ids() + assert current_state_ids is not None + assert state_group is not None return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) async def get_joined_users_from_state( - self, room_id, state_entry + self, room_id: str, state_entry: "_StateCacheEntry" ) -> Dict[str, ProfileInfo]: - state_group = state_entry.state_group + state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() + assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): return await self._get_joined_users_from_context( room_id, state_group, state_entry.state, context=state_entry @@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) async def _get_joined_users_from_context( self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, + room_id: str, + state_group: Union[object, int], + current_state_ids: StateMap[str], + cache_context: _CacheContext, + event: Optional[EventBase] = None, + context: Optional[Union[EventContext, "_StateCacheEntry"]] = None, ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states @@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): return users_in_room @cached(max_entries=10000) - def _get_joined_profile_from_event_id(self, event_id): + def _get_joined_profile_from_event_id( + self, event_id: str + ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): + async def _get_joined_profiles_from_event_ids( + self, event_ids: Iterable[str] + ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_ids: The member event IDs to lookup Returns: - dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID - to `user_id` and ProfileInfo (or None if not join event). + Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ rows = await self.db_pool.simple_select_many_batch( @@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True - async def get_joined_hosts(self, room_id: str, state_entry): - state_group = state_entry.state_group + async def get_joined_hosts( + self, room_id: str, state_entry: "_StateCacheEntry" + ) -> FrozenSet[str]: + state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() + assert state_group is not None with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( room_id, state_group, state_entry=state_entry @@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, room_id: str, state_group: int, state_entry: "_StateCacheEntry" + self, + room_id: str, + state_group: Union[object, int], + state_entry: "_StateCacheEntry", ) -> FrozenSet[str]: # We don't use `state_group`, it's there so that we can cache based on # it. However, its important that its never None, since two @@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. - cache = await self._get_joined_hosts_cache(room_id) + cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc] # If the state group in the cache matches, we already have the data we need. if state_entry.state_group == cache.state_group: @@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): elif state_entry.prev_group == cache.state_group: # The cached work is for the previous state group, so we work out # the delta. + assert state_entry.delta_ids is not None for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: continue @@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns False if they have since re-joined.""" - def f(txn): + def f(txn: LoggingTransaction) -> int: sql = ( "SELECT" " COUNT(*)" @@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): The forgotten rooms. """ - def _get_forgotten_rooms_for_user_txn(txn): + def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]: # This is a slightly convoluted query that first looks up all rooms # that the user has forgotten in the past, then rechecks that list # to see if any have subsequently been updated. This is done so that @@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): clause, ) - def _is_local_host_in_room_ignoring_users_txn(txn): + def _is_local_host_in_room_ignoring_users_txn( + txn: LoggingTransaction, + ) -> bool: txn.execute(sql, (room_id, Membership.JOIN, *args)) return bool(txn.fetchone()) @@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): where_clause="forgotten = 1", ) - async def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start + "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 + "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] ) - def add_membership_profile_txn(txn): + def add_membership_profile_txn(txn: LoggingTransaction) -> int: sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events @@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): return result - async def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership( + self, progress: JsonDict, batch_size: int + ) -> int: """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. """ - def _background_current_state_membership_txn(txn, last_processed_room): + def _background_current_state_membership_txn( + txn: LoggingTransaction, last_processed_room: str + ) -> Tuple[int, bool]: processed = 0 while processed < batch_size: txn.execute( @@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): return row_count -class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): +class RoomMemberStore( + RoomMemberWorkerStore, + RoomMemberBackgroundUpdateStore, + CacheInvalidationWorkerStore, +): def __init__( self, database: DatabasePool, @@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" - def f(txn): + def f(txn: LoggingTransaction) -> None: sql = ( "UPDATE" " room_memberships" @@ -1288,5 +1328,5 @@ class _JoinedHostsCache: # equal to anything else). state_group: Union[object, int] = attr.Factory(object) - def __len__(self): + def __len__(self) -> int: return sum(len(v) for v in self.hosts_to_joined_users.values()) |