Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/databases/main/censor_events.py          8
-rw-r--r--  synapse/storage/databases/main/client_ips.py              2
-rw-r--r--  synapse/storage/databases/main/events.py                 34
-rw-r--r--  synapse/storage/databases/main/filtering.py               8
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py   36
-rw-r--r--  synapse/storage/databases/main/push_rule.py               8
-rw-r--r--  synapse/storage/databases/main/pusher.py                 10
-rw-r--r--  synapse/storage/databases/main/registration.py           17
-rw-r--r--  synapse/storage/databases/main/room.py                    8
-rw-r--r--  synapse/storage/databases/main/room_batch.py              6
-rw-r--r--  synapse/storage/databases/main/search.py                  4
-rw-r--r--  synapse/storage/databases/main/user_directory.py         81
-rw-r--r--  synapse/storage/prepare_database.py                       6
-rw-r--r--  synapse/storage/schema/__init__.py                        6
-rw-r--r--  synapse/storage/util/id_generators.py                   143
-rw-r--r--  synapse/storage/util/sequence.py                           6
16 files changed, 237 insertions, 146 deletions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 6305414e3d..eee07227ef 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
         if (
             hs.config.worker.run_background_tasks
-            and self.hs.config.redaction_retention_period is not None
+            and self.hs.config.server.redaction_retention_period is not None
         ):
             hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
@@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
         By censor we mean update the event_json table with the redacted event.
         """
 
-        if self.hs.config.redaction_retention_period is None:
+        if self.hs.config.server.redaction_retention_period is None:
             return
 
         if not (
@@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # created.
             return
 
-        before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+        before_ts = (
+            self._clock.time_msec() - self.hs.config.server.redaction_retention_period
+        )
 
         # We fetch all redactions that:
         # 1. point to an event we have,
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 5f611d7b09..400831282a 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self.user_ips_max_age = hs.config.user_ips_max_age
+        self.user_ips_max_age = hs.config.server.user_ips_max_age
 
         if hs.config.worker.run_background_tasks and self.user_ips_max_age:
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 584f818ff3..19f55c19c5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -104,7 +104,7 @@ class PersistEventsStore:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
-        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+        self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
@@ -1276,13 +1276,6 @@ class PersistEventsStore:
                     logger.exception("")
                     raise
 
-                # update the stored internal_metadata to update the "outlier" flag.
-                # TODO: This is unused as of Synapse 1.31. Remove it once we are happy
-                # to drop backwards-compatibility with 1.30.
-                metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
-                sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
-                txn.execute(sql, (metadata_json, event.event_id))
-
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
                 stream_order = event.internal_metadata.stream_ordering
@@ -1327,19 +1320,6 @@ class PersistEventsStore:
             d.pop("redacted_because", None)
             return d
 
-        def get_internal_metadata(event):
-            im = event.internal_metadata.get_dict()
-
-            # temporary hack for database compatibility with Synapse 1.30 and earlier:
-            # store the `outlier` flag inside the internal_metadata json as well as in
-            # the `events` table, so that if anyone rolls back to an older Synapse,
-            # things keep working. This can be removed once we are happy to drop support
-            # for that
-            if event.internal_metadata.is_outlier():
-                im["outlier"] = True
-
-            return im
-
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_json",
@@ -1348,7 +1328,7 @@
                     "event_id": event.event_id,
                     "room_id": event.room_id,
                     "internal_metadata": json_encoder.encode(
-                        get_internal_metadata(event)
+                        event.internal_metadata.get_dict()
                     ),
                     "json": json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
@@ -1783,9 +1763,8 @@ class PersistEventsStore:
             retcol="creator",
             allow_none=True,
         )
-        if (
-            not room_version.msc2716_historical
-            or not self.hs.config.experimental.msc2716_enabled
+        if not room_version.msc2716_historical and (
+            not self.hs.config.experimental.msc2716_enabled
             or event.sender != room_creator
         ):
             return
@@ -1845,9 +1824,8 @@ class PersistEventsStore:
             retcol="creator",
             allow_none=True,
         )
-        if (
-            not room_version.msc2716_historical
-            or not self.hs.config.experimental.msc2716_enabled
+        if not room_version.msc2716_historical and (
+            not self.hs.config.experimental.msc2716_enabled
             or event.sender != room_creator
         ):
             return
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index bb244a03c0..434986fa64 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Union
+
 from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
 
 class FilteringStore(SQLBaseStore):
     @cached(num_args=2)
-    async def get_user_filter(self, user_localpart, filter_id):
+    async def get_user_filter(
+        self, user_localpart: str, filter_id: Union[int, str]
+    ) -> JsonDict:
         # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
         # with a coherent error message rather than 500 M_UNKNOWN.
         try:
@@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
+    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
        def_json = encode_canonical_json(user_filter)

        # Need an atomic transaction to SELECT the maximal ID so far then
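The corrected annotations pin down the round trip: add_user_filter hands back the integer filter ID, and get_user_filter accepts either that integer or its string form. A minimal usage sketch, assuming a datastore object exposing these two methods (the localpart and filter body below are illustrative, not from this change):

# Illustrative only: "store" is assumed to be a datastore exposing the methods typed above.
async def filter_round_trip(store) -> None:
    filter_def = {"room": {"timeline": {"limit": 10}}}
    filter_id = await store.add_user_filter("alice", filter_def)  # now typed as int
    fetched = await store.get_user_filter("alice", filter_id)     # JsonDict for the same filter
    # A stringified ID is also accepted, matching Union[int, str]:
    fetched_again = await store.get_user_filter("alice", str(filter_id))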
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b76ee51a9b..ec4d47a560 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         self._clock = hs.get_clock()
         self.hs = hs
 
-        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
-        self._max_mau_value = hs.config.max_mau_value
+        self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+        self._max_mau_value = hs.config.server.max_mau_value
 
     @cached(num_args=0)
     async def get_monthly_active_count(self) -> int:
@@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         """
 
         users = []
-        for tp in self.hs.config.mau_limits_reserved_threepids[
-            : self.hs.config.max_mau_value
+        for tp in self.hs.config.server.mau_limits_reserved_threepids[
+            : self.hs.config.server.max_mau_value
         ]:
             user_id = await self.hs.get_datastore().get_user_id_by_threepid(
                 tp["medium"], tp["address"]
@@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self._mau_stats_only = hs.config.mau_stats_only
+        self._mau_stats_only = hs.config.server.mau_stats_only
 
         # Do not add more reserved users than the total allowable number
         self.db_pool.new_transaction(
@@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             [],
             [],
             self._initialise_reserved_users,
-            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+            hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
         )
 
     def _initialise_reserved_users(self, txn, threepids):
@@ -354,3 +354,27 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
                 await self.upsert_monthly_active_user(user_id)
             elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
                 await self.upsert_monthly_active_user(user_id)
+
+    async def remove_deactivated_user_from_mau_table(self, user_id: str) -> None:
+        """
+        Removes a deactivated user from the monthly active user
+        table and resets affected caches.
+
+        Args:
+            user_id(str): the user_id to remove
+        """
+
+        rows_deleted = await self.db_pool.simple_delete(
+            table="monthly_active_users",
+            keyvalues={"user_id": user_id},
+            desc="simple_delete",
+        )
+
+        if rows_deleted != 0:
+            await self.invalidate_cache_and_stream(
+                "user_last_seen_monthly_active", (user_id,)
+            )
+            await self.invalidate_cache_and_stream("get_monthly_active_count", ())
+            await self.invalidate_cache_and_stream(
+                "get_monthly_active_count_by_service", ()
+            )
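The new remove_deactivated_user_from_mau_table is meant to be awaited from the account-deactivation path so that a deactivated user stops counting towards the MAU limit straight away. A hedged sketch of such a call site (the surrounding handler is an assumption; only the store method itself comes from this diff):

# Sketch only: a deactivation flow calling the new store method.
async def on_account_deactivated(store, user_id: str) -> None:
    # ... other clean-up (pushers, access tokens, etc.) elided ...
    # Drop the row from monthly_active_users and invalidate the MAU caches.
    await store.remove_deactivated_user_from_mau_table(user_id)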
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a7fb8cd848..fc720f5947 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import logging
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
@@ -101,7 +101,9 @@ class PushRulesWorkerStore(
             prefilled_cache=push_rules_prefill,
         )
 
-        self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+        self._users_new_default_push_rules = (
+            hs.config.server.users_new_default_push_rules
+        )
 
     @abc.abstractmethod
     def get_max_push_rules_stream_id(self):
@@ -137,7 +139,7 @@ class PushRulesWorkerStore(
         return _load_rules(rows, enabled_map, use_new_defaults)
 
     @cached(max_entries=5000)
-    async def get_push_rules_enabled_for_user(self, user_id):
+    async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index a93caae8d0..b73ce53c91 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
 
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
 
 
 class PusherWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c83089ee63..181841ee06 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID, UserInfo
@@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             return False
 
         now = self._clock.time_msec()
-        trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
+        trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000
         is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
         return is_trial
 
@@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
         We do this by grandfathering in existing user threepids assuming that
         they used one of the server configured trusted identity servers.
         """
-        id_servers = set(self.config.trusted_third_party_id_servers)
+        id_servers = set(self.config.registration.trusted_third_party_id_servers)
 
         def _bg_user_threepids_grandfather_txn(txn):
             sql = """
@@ -1775,10 +1775,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 
 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
-        self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+        self._ignore_unknown_session_error = (
+            hs.config.server.request_token_inhibit_3pid_errors
+        )
 
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 118b390e93..d69eaf80ce 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore):
         # policy.
         if not ret:
             return {
-                "min_lifetime": self.config.retention_default_min_lifetime,
-                "max_lifetime": self.config.retention_default_max_lifetime,
+                "min_lifetime": self.config.server.retention_default_min_lifetime,
+                "max_lifetime": self.config.server.retention_default_max_lifetime,
             }
 
         row = ret[0]
@@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore):
         # The default values will be None if no default policy has been defined, or if one
         # of the attributes is missing from the default policy.
         if row["min_lifetime"] is None:
-            row["min_lifetime"] = self.config.retention_default_min_lifetime
+            row["min_lifetime"] = self.config.server.retention_default_min_lifetime
 
         if row["max_lifetime"] is None:
-            row["max_lifetime"] = self.config.retention_default_max_lifetime
+            row["max_lifetime"] = self.config.server.retention_default_max_lifetime
 
         return row
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index a383388757..300a563c9e 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -18,7 +18,9 @@ from synapse.storage._base import SQLBaseStore
 
 
 class RoomBatchStore(SQLBaseStore):
-    async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+    async def get_insertion_event_by_batch_id(
+        self, room_id: str, batch_id: str
+    ) -> Optional[str]:
         """Retrieve a insertion event ID.
 
         Args:
@@ -30,7 +32,7 @@ class RoomBatchStore(SQLBaseStore):
         """
         return await self.db_pool.simple_select_one_onecol(
             table="insertion_events",
-            keyvalues={"next_batch_id": batch_id},
+            keyvalues={"room_id": room_id, "next_batch_id": batch_id},
             retcol="event_id",
             allow_none=True,
         )
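Keying the lookup on both room_id and next_batch_id means a batch ID is only resolved within the room it belongs to, rather than matching a same-named batch in any room. A hedged calling sketch (identifiers are illustrative, not from this change):

# Sketch only: callers must now scope the lookup to a room.
async def find_insertion_event(store, room_id: str, batch_id: str):
    event_id = await store.get_insertion_event_by_batch_id(room_id, batch_id)
    if event_id is None:
        # No insertion event in this room advertises the given batch ID.
        return None
    return event_id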
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 9eb74a81a0..25df8758bd 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore):
             txn:
             entries: entries to be added to the table
         """
-        if not self.hs.config.enable_search:
+        if not self.hs.config.server.enable_search:
             return
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
@@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        if not hs.config.enable_search:
+        if not hs.config.server.enable_search:
             return
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 90d65edc42..5c713a732e 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -40,12 +40,10 @@ from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
-
 TEMP_TABLE = "_temp_populate_user_directory"
 
 
 class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
-
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
@@ -230,38 +228,49 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         is_in_room = await self.is_host_joined(room_id, self.server_name)
 
         if is_in_room:
-            is_public = await self.is_room_world_readable_or_publicly_joinable(
-                room_id
-            )
-
             users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+            # Throw away users excluded from the directory.
+            users_with_profile = {
+                user_id: profile
+                for user_id, profile in users_with_profile.items()
+                if not self.hs.is_mine_id(user_id)
+                or await self.should_include_local_user_in_dir(user_id)
+            }
 
-            # Update each user in the user directory.
+            # Upsert a user_directory record for each remote user we see.
             for user_id, profile in users_with_profile.items():
+                # Local users are processed separately in
+                # `_populate_user_directory_users`; there we can read from
+                # the `profiles` table to ensure we don't leak their per-room
+                # profiles. It also means we write local users to this table
+                # exactly once, rather than once for every room they're in.
+                if self.hs.is_mine_id(user_id):
+                    continue
+                # TODO `users_with_profile` above reads from the `user_directory`
+                # table, meaning that `profile` is bespoke to this room.
+                # and this leaks remote users' per-room profiles to the user directory.
                 await self.update_profile_in_user_dir(
                     user_id, profile.display_name, profile.avatar_url
                 )
 
-            to_insert = set()
-
+            # Now update the room sharing tables to include this room.
+            is_public = await self.is_room_world_readable_or_publicly_joinable(
+                room_id
+            )
             if is_public:
-                for user_id in users_with_profile:
-                    if self.get_if_app_services_interested_in_user(user_id):
-                        continue
-
-                    to_insert.add(user_id)
-
-                if to_insert:
-                    await self.add_users_in_public_rooms(room_id, to_insert)
-                    to_insert.clear()
+                if users_with_profile:
+                    await self.add_users_in_public_rooms(
+                        room_id, users_with_profile.keys()
+                    )
             else:
+                to_insert = set()
                 for user_id in users_with_profile:
+                    # We want the set of pairs (L, M) where L and M are
+                    # in `users_with_profile` and L is local.
+                    # Do so by looking for the local user L first.
                     if not self.hs.is_mine_id(user_id):
                         continue
 
-                    if self.get_if_app_services_interested_in_user(user_id):
-                        continue
-
                     for other_user_id in users_with_profile:
                         if user_id == other_user_id:
                             continue
@@ -349,10 +358,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             )
 
             for user_id in users_to_work_on:
-                profile = await self.get_profileinfo(get_localpart_from_id(user_id))
-                await self.update_profile_in_user_dir(
-                    user_id, profile.display_name, profile.avatar_url
-                )
+                if await self.should_include_local_user_in_dir(user_id):
+                    profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+                    await self.update_profile_in_user_dir(
+                        user_id, profile.display_name, profile.avatar_url
+                    )
 
                 # We've finished processing a user. Delete it from the table.
                 await self.db_pool.simple_delete_one(
@@ -369,6 +379,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         return len(users_to_work_on)
 
+    async def should_include_local_user_in_dir(self, user: str) -> bool:
+        """Certain classes of local user are omitted from the user directory.
+        Is this user one of them?
+        """
+        # App service users aren't usually contactable, so exclude them.
+        if self.get_if_app_services_interested_in_user(user):
+            # TODO we might want to make this configurable for each app service
+            return False
+
+        # Support users are for diagnostics and should not appear in the user directory.
+        if await self.is_support_user(user):
+            return False
+
+        # Deactivated users aren't contactable, so should not appear in the user directory.
+        if await self.get_user_deactivated_status(user):
+            return False
+        return True
+
     async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
         """Check if the room is either world_readable or publically joinable"""
 
@@ -527,7 +555,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    async def update_user_directory_stream_pos(self, stream_id: int) -> None:
+    async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
         await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
@@ -537,7 +565,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
 
 class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
-
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f31880b8ec..11ca47ea28 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -366,7 +366,7 @@ def _upgrade_existing_database(
             + "new for the server to understand"
         )
 
-    # some of the deltas assume that config.server_name is set correctly, so now
+    # some of the deltas assume that server_name is set correctly, so now
     # is a good time to run the sanity check.
     if not is_empty and "main" in databases:
         from synapse.storage.databases.main import check_database_before_upgrade
@@ -487,6 +487,10 @@ def _upgrade_existing_database(
                     spec = importlib.util.spec_from_file_location(
                         module_name, absolute_path
                     )
+                    if spec is None:
+                        raise RuntimeError(
+                            f"Could not build a module spec for {module_name} at {absolute_path}"
+                        )
                     module = importlib.util.module_from_spec(spec)
                     spec.loader.exec_module(module)  # type: ignore
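importlib.util.spec_from_file_location is annotated as possibly returning None, so the new guard turns a would-be AttributeError into a clear RuntimeError. The same pattern in isolation (the function name and paths here are placeholders, not Synapse's own helper):

# Standalone sketch of the guarded dynamic-import pattern used above.
import importlib.util

def load_module_from_path(module_name: str, absolute_path: str):
    spec = importlib.util.spec_from_file_location(module_name, absolute_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not build a module spec for {module_name} at {absolute_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module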
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 573e05a482..1aee741a8b 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# When updating these values, please leave a short summary of the changes below.
-
-SCHEMA_VERSION = 64
+SCHEMA_VERSION = 64  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -46,7 +44,7 @@ Changes in SCHEMA_VERSION = 64:
 """
 
 
-SCHEMA_COMPAT_VERSION = 59
+SCHEMA_COMPAT_VERSION = 60  # 60: "outlier" not in internal_metadata.
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
 
 This value is stored in the database, and checked on startup. If the value in the
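Bumping SCHEMA_COMPAT_VERSION to 60 records that, now the "outlier" flag is no longer duplicated into internal_metadata (see the events.py hunks above), the database can only be rolled back to Synapse versions whose code understands schema 60 or newer. A simplified sketch of the start-up gate this value feeds, condensed from the check referenced in prepare_database above and not the verbatim implementation:

# Simplified sketch of the rollback-compatibility check on startup.
SCHEMA_VERSION = 64

def check_schema_compat(compat_version_in_db: int) -> None:
    if compat_version_in_db > SCHEMA_VERSION:
        # The database was written by a newer Synapse whose changes this code
        # cannot safely interpret.
        raise ValueError("Cannot use this database as it is too new for the server to understand")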
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 6f7cbe40f4..852bd79fee 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,42 +16,62 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+    AsyncContextManager,
+    ContextManager,
+    Dict,
+    Generator,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from sortedcontainers import SortedSet
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
 
 
+T = TypeVar("T")
+
+
 class IdGenerator:
-    def __init__(self, db_conn, table, column):
+    def __init__(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
+    ):
         self._lock = threading.Lock()
         self._next_id = _load_current_id(db_conn, table, column)
 
-    def get_next(self):
+    def get_next(self) -> int:
         with self._lock:
             self._next_id += 1
             return self._next_id
 
 
-def _load_current_id(db_conn, table, column, step=1):
-    """
-
-    Args:
-        db_conn (object):
-        table (str):
-        column (str):
-        step (int):
-
-    Returns:
-        int
-    """
+def _load_current_id(
+    db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
         cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
-    (val,) = cur.fetchone()
+    result = cur.fetchone()
+    assert result is not None
+    (val,) = result
     cur.close()
     current_id = int(val) if val else step
     return (max if step > 0 else min)(current_id, step)
@@ -93,16 +115,16 @@ class StreamIdGenerator:
     def __init__(
         self,
-        db_conn,
-        table,
-        column,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
         extra_tables: Iterable[Tuple[str, str]] = (),
-        step=1,
-    ):
+        step: int = 1,
+    ) -> None:
         assert step != 0
         self._lock = threading.Lock()
-        self._step = step
-        self._current = _load_current_id(db_conn, table, column, step)
+        self._step: int = step
+        self._current: int = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@ class StreamIdGenerator:
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[int, None, None]:
             try:
                 yield next_id
             finally:
@@ -137,7 +159,7 @@ class StreamIdGenerator:
 
         return _AsyncCtxManagerWrapper(manager())
 
-    def get_next_mult(self, n):
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
         """
         Usage:
             async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@ class StreamIdGenerator:
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[Sequence[int], None, None]:
             try:
                 yield next_ids
             finally:
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
 
     def __init__(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
-    ):
+    ) -> None:
         self._db = db
         self._stream_name = stream_name
         self._instance_name = instance_name
@@ -285,9 +307,9 @@ class MultiWriterIdGenerator:
 
     def _load_current_ids(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         tables: List[Tuple[str, str, str]],
-    ):
+    ) -> None:
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
         # Load the current positions of all writers for the stream.
@@ -335,7 +357,9 @@ class MultiWriterIdGenerator:
                     "agg": "MAX" if self._positive else "-MIN",
                 }
                 cur.execute(sql)
-                (stream_id,) = cur.fetchone()
+                result = cur.fetchone()
+                assert result is not None
+                (stream_id,) = result
 
                 max_stream_id = max(max_stream_id, stream_id)
 
@@ -354,7 +378,7 @@ class MultiWriterIdGenerator:
 
             self._persisted_upto_position = min_stream_id
 
-            rows = []
+            rows: List[Tuple[str, int]] = []
             for table, instance_column, id_column in tables:
                 sql = """
                     SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +391,8 @@ class MultiWriterIdGenerator:
                 }
                 cur.execute(sql, (min_stream_id * self._return_factor,))
 
-                rows.extend(cur)
+                # Cast safety: this corresponds to the types returned by the query above.
+                rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
             # Sort so that we handle rows in order for each instance.
             rows.sort()
@@ -385,13 +410,13 @@ class MultiWriterIdGenerator:
 
         cur.close()
 
-    def _load_next_id_txn(self, txn) -> int:
+    def _load_next_id_txn(self, txn: Cursor) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
-    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+    def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +428,12 @@ class MultiWriterIdGenerator:
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self)
+        # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+        # controls the return type. If `None` or omitted, the context manager yields
+        # a single integer stream_id; otherwise it yields a list of stream_ids.
+        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
-    def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
         """
         Usage:
             async with stream_id_gen.get_next_mult(5) as stream_ids:
@@ -417,9 +445,10 @@ class MultiWriterIdGenerator:
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self, n)
+        # Cast safety: see get_next.
+        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
 
-    def get_next_txn(self, txn: LoggingTransaction):
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
         """
         Usage:
@@ -457,7 +486,7 @@ class MultiWriterIdGenerator:
 
         return self._return_factor * next_id
 
-    def _mark_id_as_finished(self, next_id: int):
+    def _mark_id_as_finished(self, next_id: int) -> None:
         """The ID has finished being processed so we should advance the current
         position if possible.
         """
@@ -534,7 +563,7 @@ class MultiWriterIdGenerator:
                 for name, i in self._current_positions.items()
             }
 
-    def advance(self, instance_name: str, new_id: int):
+    def advance(self, instance_name: str, new_id: int) -> None:
         """Advance the position of the named writer to the given ID, if greater
         than existing entry.
         """
@@ -560,7 +589,7 @@ class MultiWriterIdGenerator:
         with self._lock:
             return self._return_factor * self._persisted_upto_position
 
-    def _add_persisted_position(self, new_id: int):
+    def _add_persisted_position(self, new_id: int) -> None:
         """Record that we have persisted a position. This is used to keep
         the `_current_positions` up to date.
 
@@ -606,7 +635,7 @@ class MultiWriterIdGenerator:
                 # do.
                 break
 
-    def _update_stream_positions_table_txn(self, txn: Cursor):
+    def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
         """Update the `stream_positions` table with newly persisted position."""
 
         if not self._writers:
@@ -628,20 +657,25 @@ class MultiWriterIdGenerator:
             txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
 
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
     """Helper class to convert a plain context manager to an async one.
 
     This is mainly useful if you have a plain context manager but the interface
    requires an async one.
    """
 
-    inner = attr.ib()
+    inner: ContextManager[T]
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> T:
         return self.inner.__enter__()
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
         return self.inner.__exit__(exc_type, exc, tb)
 
 
@@ -671,7 +705,12 @@ class _MultiWriterCtxManager:
         else:
             return [i * self.id_gen._return_factor for i in self.stream_ids]
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> bool:
         for i in self.stream_ids:
             self.id_gen._mark_id_as_finished(i)
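The calling convention stays exactly as the docstrings in this file describe: get_next() and get_next_mult() hand back async context managers, and the allocated stream IDs are only marked finished when the block exits. A short usage sketch (the generator instance and persist callbacks are assumed, e.g. the pushers StreamIdGenerator constructed earlier in this diff):

# Sketch only: allocating stream IDs with the (now fully typed) generators.
async def persist_one(stream_id_gen, persist) -> None:
    async with stream_id_gen.get_next() as stream_id:          # yields an int
        await persist(stream_id)

async def persist_many(stream_id_gen, persist_batch) -> None:
    async with stream_id_gen.get_next_mult(5) as stream_ids:   # yields a sequence of ints
        await persist_batch(stream_ids)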
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index bb33e04fb1..75268cbe15 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in
         the table.
@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         # There is nothing to do for in memory sequences
         pass