diff options
Diffstat (limited to 'synapse')
71 files changed, 622 insertions, 643 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 1dcc289df3..a533cad5ae 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -1040,10 +1040,10 @@ class Porter: return done, remaining + done async def _setup_state_group_id_seq(self) -> None: - curr_id: Optional[ - int - ] = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True + curr_id: Optional[int] = ( + await self.sqlite_store.db_pool.simple_select_one_onecol( + table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True + ) ) if not curr_id: @@ -1132,13 +1132,13 @@ class Porter: ) async def _setup_auth_chain_sequence(self) -> None: - curr_chain_id: Optional[ - int - ] = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="event_auth_chains", - keyvalues={}, - retcol="MAX(chain_id)", - allow_none=True, + curr_chain_id: Optional[int] = ( + await self.sqlite_store.db_pool.simple_select_one_onecol( + table="event_auth_chains", + keyvalues={}, + retcol="MAX(chain_id)", + allow_none=True, + ) ) def r(txn: LoggingTransaction) -> None: diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f3d2c8073d..d25aff98ff 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -43,7 +43,6 @@ MAIN_TIMELINE: Final = "main" class Membership: - """Represents the membership states of a user in a room.""" INVITE: Final = "invite" diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 7ff8ad2d55..fbc1d58ecb 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -370,9 +370,11 @@ class RoomVersionCapability: MSC3244_CAPABILITIES = { cap.identifier: { - "preferred": cap.preferred_version.identifier - if cap.preferred_version is not None - else None, + "preferred": ( + cap.preferred_version.identifier + if cap.preferred_version is not None + else None + ), "support": [ v.identifier for v in KNOWN_ROOM_VERSIONS.values() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b241dbf627..8a545a86c1 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -188,9 +188,9 @@ class SynapseHomeServer(HomeServer): PasswordResetSubmitTokenResource, ) - resources[ - "/_synapse/client/password_reset/email/submit_token" - ] = PasswordResetSubmitTokenResource(self) + resources["/_synapse/client/password_reset/email/submit_token"] = ( + PasswordResetSubmitTokenResource(self) + ) if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 34fa2bb655..19322471dc 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -362,16 +362,16 @@ class ApplicationServiceApi(SimpleHttpClient): # TODO: Update to stable prefixes once MSC3202 completes FCP merge if service.msc3202_transaction_extensions: if one_time_keys_count: - body[ - "org.matrix.msc3202.device_one_time_key_counts" - ] = one_time_keys_count - body[ - "org.matrix.msc3202.device_one_time_keys_count" - ] = one_time_keys_count + body["org.matrix.msc3202.device_one_time_key_counts"] = ( + one_time_keys_count + ) + body["org.matrix.msc3202.device_one_time_keys_count"] = ( + one_time_keys_count + ) if unused_fallback_keys: - body[ - "org.matrix.msc3202.device_unused_fallback_key_types" - ] = unused_fallback_keys + body["org.matrix.msc3202.device_unused_fallback_key_types"] = ( + unused_fallback_keys + ) if device_list_summary: body["org.matrix.msc3202.device_lists"] = { "changed": list(device_list_summary.changed), diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 3fe0f050cd..c7f3e6d35e 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -171,9 +171,9 @@ class RegistrationConfig(Config): refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime ) - self.refreshable_access_token_lifetime: Optional[ - int - ] = refreshable_access_token_lifetime + self.refreshable_access_token_lifetime: Optional[int] = ( + refreshable_access_token_lifetime + ) if ( self.session_lifetime is not None diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 4655882b4b..1645470499 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -199,9 +199,9 @@ class ContentRepositoryConfig(Config): provider_config["module"] == "file_system" or provider_config["module"] == "synapse.rest.media.v1.storage_provider" ): - provider_config[ - "module" - ] = "synapse.media.storage_provider.FileStorageProviderBackend" + provider_config["module"] = ( + "synapse.media.storage_provider.FileStorageProviderBackend" + ) provider_class, parsed_config = load_module( provider_config, ("media_storage_providers", "<item %i>" % i) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c8b06f760e..f5abcde2db 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -88,8 +88,7 @@ class _EventSourceStore(Protocol): redact_behaviour: EventRedactBehaviour, get_prev_content: bool = False, allow_rejected: bool = False, - ) -> Dict[str, "EventBase"]: - ... + ) -> Dict[str, "EventBase"]: ... def validate_event_for_room_version(event: "EventBase") -> None: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 7ec696c6c0..36e0f47e51 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -93,16 +93,14 @@ class DictProperty(Generic[T]): self, instance: Literal[None], owner: Optional[Type[_DictPropertyInstance]] = None, - ) -> "DictProperty": - ... + ) -> "DictProperty": ... @overload def __get__( self, instance: _DictPropertyInstance, owner: Optional[Type[_DictPropertyInstance]] = None, - ) -> T: - ... + ) -> T: ... def __get__( self, @@ -161,16 +159,14 @@ class DefaultDictProperty(DictProperty, Generic[T]): self, instance: Literal[None], owner: Optional[Type[_DictPropertyInstance]] = None, - ) -> "DefaultDictProperty": - ... + ) -> "DefaultDictProperty": ... @overload def __get__( self, instance: _DictPropertyInstance, owner: Optional[Type[_DictPropertyInstance]] = None, - ) -> T: - ... + ) -> T: ... def __get__( self, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index cc52d0d1e9..e0613d0dbc 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -612,9 +612,9 @@ class EventClientSerializer: serialized_aggregations = {} if event_aggregations.references: - serialized_aggregations[ - RelationTypes.REFERENCE - ] = event_aggregations.references + serialized_aggregations[RelationTypes.REFERENCE] = ( + event_aggregations.references + ) if event_aggregations.replace: # Include information about it in the relations dict. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index dc8cd5ec9a..65d3a661fe 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -169,9 +169,9 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache: ResponseCache[ - Tuple[str, Optional[str]] - ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000) + self._state_resp_cache: ResponseCache[Tuple[str, Optional[str]]] = ( + ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000) + ) self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index e9a2386a5c..b5c9fcff7c 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -88,9 +88,9 @@ class FederationRemoteSendQueue(AbstractFederationSender): # Stores the destinations we need to explicitly send presence to about a # given user. # Stream position -> (user_id, destinations) - self.presence_destinations: SortedDict[ - int, Tuple[str, Iterable[str]] - ] = SortedDict() + self.presence_destinations: SortedDict[int, Tuple[str, Iterable[str]]] = ( + SortedDict() + ) # (destination, key) -> EDU self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {} diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index 37cc3d3ff5..89e944bc17 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -118,10 +118,10 @@ class AccountHandler: } if self._use_account_validity_in_account_status: - status[ - "org.matrix.expired" - ] = await self._account_validity_handler.is_user_expired( - user_id.to_string() + status["org.matrix.expired"] = ( + await self._account_validity_handler.is_user_expired( + user_id.to_string() + ) ) return status diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f233f1b034..a1fab99f6b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2185,7 +2185,7 @@ class PasswordAuthProvider: # result is always the right type, but as it is 3rd party code it might not be if not isinstance(result, tuple) or len(result) != 2: - logger.warning( + logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" " Optional[Tuple[str, Optional[Callable]]]", callback, @@ -2248,7 +2248,7 @@ class PasswordAuthProvider: # result is always the right type, but as it is 3rd party code it might not be if not isinstance(result, tuple) or len(result) != 2: - logger.warning( + logger.warning( # type: ignore[unreachable] "Wrong type returned by module API callback %s: %s, expected" " Optional[Tuple[str, Optional[Callable]]]", callback, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 5f3dc30b63..ad2b0f5fcc 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -265,9 +265,9 @@ class DirectoryHandler: async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result: Optional[ - RoomAliasMapping - ] = await self.get_association_from_room_alias(room_alias) + result: Optional[RoomAliasMapping] = ( + await self.get_association_from_room_alias(room_alias) + ) if result: room_id = result.room_id diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2b7aad5b58..299588e476 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1001,11 +1001,11 @@ class FederationHandler: ) if include_auth_user_id: - event_content[ - EventContentFields.AUTHORISING_USER - ] = await self._event_auth_handler.get_user_which_could_invite( - room_id, - state_ids, + event_content[EventContentFields.AUTHORISING_USER] = ( + await self._event_auth_handler.get_user_which_could_invite( + room_id, + state_ids, + ) ) builder = self.event_builder_factory.for_room_version( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 83f6a25981..c85deaed56 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1367,9 +1367,9 @@ class FederationEventHandler: ) if remote_event.is_state() and remote_event.rejected_reason is None: - state_map[ - (remote_event.type, remote_event.state_key) - ] = remote_event.event_id + state_map[(remote_event.type, remote_event.state_key)] = ( + remote_event.event_id + ) return state_map diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7e5bb97f2a..0ce6eeee15 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1654,9 +1654,9 @@ class EventCreationHandler: expiry_ms=60 * 60 * 1000, ) - self._external_cache_joined_hosts_updates[ - state_entry.state_group - ] = None + self._external_cache_joined_hosts_updates[state_entry.state_group] = ( + None + ) async def _validate_canonical_alias( self, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 21d3c71d8e..37ee625f71 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -493,9 +493,9 @@ class WorkerPresenceHandler(BasePresenceHandler): # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_device_to_num_current_syncs: Dict[ - Tuple[str, Optional[str]], int - ] = {} + self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = ( + {} + ) self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() @@ -818,9 +818,9 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self._user_device_to_num_current_syncs: Dict[ - Tuple[str, Optional[str]], int - ] = {} + self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = ( + {} + ) # Keeps track of the number of *ongoing* syncs on other processes. # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 279d393a5a..e51e282a9f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -320,9 +320,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info: Optional[ - Union[LocalMedia, RemoteMedia] - ] = await self.store.get_local_media(media_id) + media_info: Optional[Union[LocalMedia, RemoteMedia]] = ( + await self.store.get_local_media(media_id) + ) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 828a4b4cbd..931ac0c813 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -188,13 +188,13 @@ class RelationsHandler: if include_original_event: # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - return_value[ - "original_event" - ] = await self._event_serializer.serialize_event( - event, - now, - bundle_aggregations=None, - config=serialize_options, + return_value["original_event"] = ( + await self._event_serializer.serialize_event( + event, + now, + bundle_aggregations=None, + config=serialize_options, + ) ) if next_token: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6b116dce8c..3278426ca3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -538,10 +538,10 @@ class RoomCreationHandler: # deep-copy the power-levels event before we start modifying it # note that if frozen_dicts are enabled, `power_levels` will be a frozen # dict so we can't just copy.deepcopy it. - initial_state[ - (EventTypes.PowerLevels, "") - ] = power_levels = copy_and_fixup_power_levels_contents( - initial_state[(EventTypes.PowerLevels, "")] + initial_state[(EventTypes.PowerLevels, "")] = power_levels = ( + copy_and_fixup_power_levels_contents( + initial_state[(EventTypes.PowerLevels, "")] + ) ) # Resolve the minimum power level required to send any state event @@ -1362,9 +1362,11 @@ class RoomCreationHandler: visibility = room_config.get("visibility", "private") preset_name = room_config.get( "preset", - RoomCreationPreset.PRIVATE_CHAT - if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT, + ( + RoomCreationPreset.PRIVATE_CHAT + if visibility == "private" + else RoomCreationPreset.PUBLIC_CHAT + ), ) try: preset_config = self._presets_dict[preset_name] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d238c40bcf..9e9f6cd062 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1236,11 +1236,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If this is going to be a local join, additional information must # be included in the event content in order to efficiently validate # the event. - content[ - EventContentFields.AUTHORISING_USER - ] = await self.event_auth_handler.get_user_which_could_invite( - room_id, - state_before_join, + content[EventContentFields.AUTHORISING_USER] = ( + await self.event_auth_handler.get_user_which_could_invite( + room_id, + state_before_join, + ) ) return False, [] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 437cb5509c..8e39e76c97 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -150,7 +150,7 @@ class UserAttributes: display_name: Optional[str] = None picture: Optional[str] = None # mypy thinks these are incompatible for some reason. - emails: StrCollection = attr.Factory(list) # type: ignore[assignment] + emails: StrCollection = attr.Factory(list) @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9122a79b4c..0aedb37f16 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1014,30 +1014,6 @@ class SyncHandler: if event.is_state(): timeline_state[(event.type, event.state_key)] = event.event_id - if full_state: - # always make sure we LL ourselves so we know we're in the room - # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209 - # We only need apply this on full state syncs given we disabled - # LL for incr syncs in https://github.com/matrix-org/synapse/pull/3840. - # We don't insert ourselves into `members_to_fetch`, because in some - # rare cases (an empty event batch with a now_token after the user's - # leave in a partial state room which another local user has - # joined), the room state will be missing our membership and there - # is no guarantee that our membership will be in the auth events of - # timeline events when the room is partial stated. - state_filter = StateFilter.from_lazy_load_member_list( - members_to_fetch.union((sync_config.user.to_string(),)) - ) - else: - state_filter = StateFilter.from_lazy_load_member_list( - members_to_fetch - ) - - # We are happy to use partial state to compute the `/sync` response. - # Since partial state may not include the lazy-loaded memberships we - # require, we fix up the state response afterwards with memberships from - # auth events. - await_full_state = False else: timeline_state = { (event.type, event.state_key): event.event_id @@ -1045,9 +1021,6 @@ class SyncHandler: if event.is_state() } - state_filter = StateFilter.all() - await_full_state = True - # Now calculate the state to return in the sync response for the room. # This is more or less the change in state between the end of the previous # sync's timeline and the start of the current sync's timeline. @@ -1057,131 +1030,28 @@ class SyncHandler: # whether the room is partial stated *before* fetching it. is_partial_state_room = await self.store.is_partial_state_room(room_id) if full_state: - if batch: - state_at_timeline_end = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) - - state_at_timeline_start = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) - - else: - state_at_timeline_end = await self.get_state_at( - room_id, - stream_position=now_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) - - state_at_timeline_start = state_at_timeline_end - - state_ids = _calculate_state( - timeline_contains=timeline_state, - timeline_start=state_at_timeline_start, - timeline_end=state_at_timeline_end, - previous_timeline_end={}, - lazy_load_members=lazy_load_members, + state_ids = await self._compute_state_delta_for_full_sync( + room_id, + sync_config.user, + batch, + now_token, + members_to_fetch, + timeline_state, ) - elif batch.limited: - if batch: - state_at_timeline_start = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) - else: - # We can get here if the user has ignored the senders of all - # the recent events. - state_at_timeline_start = await self.get_state_at( - room_id, - stream_position=now_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) - - # for now, we disable LL for gappy syncs - see - # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 - # N.B. this slows down incr syncs as we are now processing way - # more state in the server than if we were LLing. - # - # We still have to filter timeline_start to LL entries (above) in order - # for _calculate_state's LL logic to work, as we have to include LL - # members for timeline senders in case they weren't loaded in the initial - # sync. We do this by (counterintuitively) by filtering timeline_start - # members to just be ones which were timeline senders, which then ensures - # all of the rest get included in the state block (if we need to know - # about them). - state_filter = StateFilter.all() - + else: # If this is an initial sync then full_state should be set, and # that case is handled above. We assert here to ensure that this # is indeed the case. assert since_token is not None - state_at_previous_sync = await self.get_state_at( - room_id, - stream_position=since_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) - if batch: - state_at_timeline_end = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) - else: - # We can get here if the user has ignored the senders of all - # the recent events. - state_at_timeline_end = await self.get_state_at( - room_id, - stream_position=now_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) - - state_ids = _calculate_state( - timeline_contains=timeline_state, - timeline_start=state_at_timeline_start, - timeline_end=state_at_timeline_end, - previous_timeline_end=state_at_previous_sync, - # we have to include LL members in case LL initial sync missed them - lazy_load_members=lazy_load_members, + state_ids = await self._compute_state_delta_for_incremental_sync( + room_id, + batch, + since_token, + now_token, + members_to_fetch, + timeline_state, ) - else: - state_ids = {} - if lazy_load_members: - if members_to_fetch and batch.events: - # We're returning an incremental sync, with no - # "gap" since the previous sync, so normally there would be - # no state to return. - # But we're lazy-loading, so the client might need some more - # member events to understand the events in this timeline. - # So we fish out all the member events corresponding to the - # timeline here, and then dedupe any redundant ones below. - - state_ids = await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - # we only want members! - state_filter=StateFilter.from_types( - (EventTypes.Member, member) - for member in members_to_fetch - ), - await_full_state=False, - ) # If we only have partial state for the room, `state_ids` may be missing the # memberships we wanted. We attempt to find some by digging through the auth @@ -1245,6 +1115,227 @@ class SyncHandler: if e.type != EventTypes.Aliases # until MSC2261 or alternative solution } + async def _compute_state_delta_for_full_sync( + self, + room_id: str, + syncing_user: UserID, + batch: TimelineBatch, + now_token: StreamToken, + members_to_fetch: Optional[Set[str]], + timeline_state: StateMap[str], + ) -> StateMap[str]: + """Calculate the state events to be included in a full sync response. + + As with `_compute_state_delta_for_incremental_sync`, the result will include + the membership events for the senders of each event in `members_to_fetch`. + + Args: + room_id: The room we are calculating for. + syncing_user: The user that is calling `/sync`. + batch: The timeline batch for the room that will be sent to the user. + now_token: Token of the end of the current batch. + members_to_fetch: If lazy-loading is enabled, the memberships needed for + events in the timeline. + timeline_state: The contribution to the room state from state events in + `batch`. Only contains the last event for any given state key. + + Returns: + A map from (type, state_key) to event_id, for each event that we believe + should be included in the `state` part of the sync response. + """ + if members_to_fetch is not None: + # Lazy-loading of membership events is enabled. + # + # Always make sure we load our own membership event so we know if + # we're in the room, to fix https://github.com/vector-im/riot-web/issues/7209. + # + # We only need apply this on full state syncs given we disabled + # LL for incr syncs in https://github.com/matrix-org/synapse/pull/3840. + # + # We don't insert ourselves into `members_to_fetch`, because in some + # rare cases (an empty event batch with a now_token after the user's + # leave in a partial state room which another local user has + # joined), the room state will be missing our membership and there + # is no guarantee that our membership will be in the auth events of + # timeline events when the room is partial stated. + state_filter = StateFilter.from_lazy_load_member_list( + members_to_fetch.union((syncing_user.to_string(),)) + ) + + # We are happy to use partial state to compute the `/sync` response. + # Since partial state may not include the lazy-loaded memberships we + # require, we fix up the state response afterwards with memberships from + # auth events. + await_full_state = False + lazy_load_members = True + else: + state_filter = StateFilter.all() + await_full_state = True + lazy_load_members = False + + if batch: + state_at_timeline_end = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, + ) + ) + + state_at_timeline_start = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, + ) + ) + else: + state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + + state_at_timeline_start = state_at_timeline_end + + state_ids = _calculate_state( + timeline_contains=timeline_state, + timeline_start=state_at_timeline_start, + timeline_end=state_at_timeline_end, + previous_timeline_end={}, + lazy_load_members=lazy_load_members, + ) + return state_ids + + async def _compute_state_delta_for_incremental_sync( + self, + room_id: str, + batch: TimelineBatch, + since_token: StreamToken, + now_token: StreamToken, + members_to_fetch: Optional[Set[str]], + timeline_state: StateMap[str], + ) -> StateMap[str]: + """Calculate the state events to be included in an incremental sync response. + + If lazy-loading of membership events is enabled (as indicated by + `members_to_fetch` being not-`None`), the result will include the membership + events for each member in `members_to_fetch`. The caller + (`compute_state_delta`) is responsible for keeping track of which membership + events we have already sent to the client, and hence ripping them out. + + Args: + room_id: The room we are calculating for. + batch: The timeline batch for the room that will be sent to the user. + since_token: Token of the end of the previous batch. + now_token: Token of the end of the current batch. + members_to_fetch: If lazy-loading is enabled, the memberships needed for + events in the timeline. Otherwise, `None`. + timeline_state: The contribution to the room state from state events in + `batch`. Only contains the last event for any given state key. + + Returns: + A map from (type, state_key) to event_id, for each event that we believe + should be included in the `state` part of the sync response. + """ + if members_to_fetch is not None: + # Lazy-loading is enabled. Only return the state that is needed. + state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + await_full_state = False + lazy_load_members = True + else: + state_filter = StateFilter.all() + await_full_state = True + lazy_load_members = False + + if batch.limited: + if batch: + state_at_timeline_start = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, + ) + ) + else: + # We can get here if the user has ignored the senders of all + # the recent events. + state_at_timeline_start = await self.get_state_at( + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + + # for now, we disable LL for gappy syncs - see + # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 + # N.B. this slows down incr syncs as we are now processing way + # more state in the server than if we were LLing. + # + # We still have to filter timeline_start to LL entries (above) in order + # for _calculate_state's LL logic to work, as we have to include LL + # members for timeline senders in case they weren't loaded in the initial + # sync. We do this by (counterintuitively) by filtering timeline_start + # members to just be ones which were timeline senders, which then ensures + # all of the rest get included in the state block (if we need to know + # about them). + state_filter = StateFilter.all() + + state_at_previous_sync = await self.get_state_at( + room_id, + stream_position=since_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + + if batch: + state_at_timeline_end = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, + ) + ) + else: + # We can get here if the user has ignored the senders of all + # the recent events. + state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + + state_ids = _calculate_state( + timeline_contains=timeline_state, + timeline_start=state_at_timeline_start, + timeline_end=state_at_timeline_end, + previous_timeline_end=state_at_previous_sync, + lazy_load_members=lazy_load_members, + ) + else: + state_ids = {} + if lazy_load_members: + if members_to_fetch and batch.events: + # We're returning an incremental sync, with no + # "gap" since the previous sync, so normally there would be + # no state to return. + # But we're lazy-loading, so the client might need some more + # member events to understand the events in this timeline. + # So we fish out all the member events corresponding to the + # timeline here. The caller will then dedupe any redundant ones. + + state_ids = await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) for member in members_to_fetch + ), + await_full_state=False, + ) + return state_ids + async def _find_missing_partial_state_memberships( self, room_id: str, @@ -1333,9 +1424,9 @@ class SyncHandler: and auth_event.state_key == member ): missing_members.discard(member) - additional_state_ids[ - (EventTypes.Member, member) - ] = auth_event.event_id + additional_state_ids[(EventTypes.Member, member)] = ( + auth_event.event_id + ) break if missing_members: @@ -2746,7 +2837,7 @@ class SyncResultBuilder: if self.since_token: for joined_sync in self.joined: it = itertools.chain( - joined_sync.timeline.events, joined_sync.state.values() + joined_sync.state.values(), joined_sync.timeline.events ) for event in it: if event.type == EventTypes.Member: @@ -2758,13 +2849,20 @@ class SyncResultBuilder: newly_joined_or_invited_or_knocked_users.add( event.state_key ) + # If the user left and rejoined in the same batch, they + # count as a newly-joined user, *not* a newly-left user. + newly_left_users.discard(event.state_key) else: prev_content = event.unsigned.get("prev_content", {}) prev_membership = prev_content.get("membership", None) if prev_membership == Membership.JOIN: newly_left_users.add(event.state_key) + # If the user joined and left in the same batch, they + # count as a newly-left user, not a newly-joined user. + newly_joined_or_invited_or_knocked_users.discard( + event.state_key + ) - newly_left_users -= newly_joined_or_invited_or_knocked_users return newly_joined_or_invited_or_knocked_users, newly_left_users diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py index a870fd1124..7e578cf462 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py @@ -182,12 +182,15 @@ class WorkerLocksHandler: if not locks: return - def _wake_deferred(deferred: defer.Deferred) -> None: - if not deferred.called: - deferred.callback(None) - - for lock in locks: - self._clock.call_later(0, _wake_deferred, lock.deferred) + def _wake_all_locks( + locks: Collection[Union[WaitingLock, WaitingMultiLock]] + ) -> None: + for lock in locks: + deferred = lock.deferred + if not deferred.called: + deferred.callback(None) + + self._clock.call_later(0, _wake_all_locks, locks) @wrap_as_background_process("_cleanup_locks") async def _cleanup_locks(self) -> None: diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 884ecdacdd..c73a589e6c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -931,8 +931,7 @@ class MatrixFederationHttpClient: try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, backoff_on_all_error_codes: bool = False, - ) -> JsonDict: - ... + ) -> JsonDict: ... @overload async def put_json( @@ -949,8 +948,7 @@ class MatrixFederationHttpClient: try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, backoff_on_all_error_codes: bool = False, - ) -> T: - ... + ) -> T: ... async def put_json( self, @@ -1140,8 +1138,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - ) -> JsonDict: - ... + ) -> JsonDict: ... @overload async def get_json( @@ -1154,8 +1151,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = ..., try_trailing_slash_on_400: bool = ..., parser: ByteParser[T] = ..., - ) -> T: - ... + ) -> T: ... async def get_json( self, @@ -1236,8 +1232,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: - ... + ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: ... @overload async def get_json_with_headers( @@ -1250,8 +1245,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = ..., try_trailing_slash_on_400: bool = ..., parser: ByteParser[T] = ..., - ) -> Tuple[T, Dict[bytes, List[bytes]]]: - ... + ) -> Tuple[T, Dict[bytes, List[bytes]]]: ... async def get_json_with_headers( self, diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index b22eb727b1..b73d06f1d3 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -61,20 +61,17 @@ logger = logging.getLogger(__name__) @overload -def parse_integer(request: Request, name: str, default: int) -> int: - ... +def parse_integer(request: Request, name: str, default: int) -> int: ... @overload -def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: - ... +def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: ... @overload def parse_integer( request: Request, name: str, default: Optional[int] = None, required: bool = False -) -> Optional[int]: - ... +) -> Optional[int]: ... def parse_integer( @@ -105,8 +102,7 @@ def parse_integer_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[int] = None, -) -> Optional[int]: - ... +) -> Optional[int]: ... @overload @@ -115,8 +111,7 @@ def parse_integer_from_args( name: str, *, required: Literal[True], -) -> int: - ... +) -> int: ... @overload @@ -125,8 +120,7 @@ def parse_integer_from_args( name: str, default: Optional[int] = None, required: bool = False, -) -> Optional[int]: - ... +) -> Optional[int]: ... def parse_integer_from_args( @@ -172,20 +166,17 @@ def parse_integer_from_args( @overload -def parse_boolean(request: Request, name: str, default: bool) -> bool: - ... +def parse_boolean(request: Request, name: str, default: bool) -> bool: ... @overload -def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: - ... +def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: ... @overload def parse_boolean( request: Request, name: str, default: Optional[bool] = None, required: bool = False -) -> Optional[bool]: - ... +) -> Optional[bool]: ... def parse_boolean( @@ -216,8 +207,7 @@ def parse_boolean_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, default: bool, -) -> bool: - ... +) -> bool: ... @overload @@ -226,8 +216,7 @@ def parse_boolean_from_args( name: str, *, required: Literal[True], -) -> bool: - ... +) -> bool: ... @overload @@ -236,8 +225,7 @@ def parse_boolean_from_args( name: str, default: Optional[bool] = None, required: bool = False, -) -> Optional[bool]: - ... +) -> Optional[bool]: ... def parse_boolean_from_args( @@ -289,8 +277,7 @@ def parse_bytes_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, -) -> Optional[bytes]: - ... +) -> Optional[bytes]: ... @overload @@ -300,8 +287,7 @@ def parse_bytes_from_args( default: Literal[None] = None, *, required: Literal[True], -) -> bytes: - ... +) -> bytes: ... @overload @@ -310,8 +296,7 @@ def parse_bytes_from_args( name: str, default: Optional[bytes] = None, required: bool = False, -) -> Optional[bytes]: - ... +) -> Optional[bytes]: ... def parse_bytes_from_args( @@ -355,8 +340,7 @@ def parse_string( *, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> str: - ... +) -> str: ... @overload @@ -367,8 +351,7 @@ def parse_string( required: Literal[True], allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> str: - ... +) -> str: ... @overload @@ -380,8 +363,7 @@ def parse_string( required: bool = False, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> Optional[str]: - ... +) -> Optional[str]: ... def parse_string( @@ -437,8 +419,7 @@ def parse_enum( name: str, E: Type[EnumT], default: EnumT, -) -> EnumT: - ... +) -> EnumT: ... @overload @@ -448,8 +429,7 @@ def parse_enum( E: Type[EnumT], *, required: Literal[True], -) -> EnumT: - ... +) -> EnumT: ... def parse_enum( @@ -526,8 +506,7 @@ def parse_strings_from_args( *, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> Optional[List[str]]: - ... +) -> Optional[List[str]]: ... @overload @@ -538,8 +517,7 @@ def parse_strings_from_args( *, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> List[str]: - ... +) -> List[str]: ... @overload @@ -550,8 +528,7 @@ def parse_strings_from_args( required: Literal[True], allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> List[str]: - ... +) -> List[str]: ... @overload @@ -563,8 +540,7 @@ def parse_strings_from_args( required: bool = False, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> Optional[List[str]]: - ... +) -> Optional[List[str]]: ... def parse_strings_from_args( @@ -625,8 +601,7 @@ def parse_string_from_args( *, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> Optional[str]: - ... +) -> Optional[str]: ... @overload @@ -638,8 +613,7 @@ def parse_string_from_args( required: Literal[True], allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> str: - ... +) -> str: ... @overload @@ -650,8 +624,7 @@ def parse_string_from_args( required: bool = False, allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", -) -> Optional[str]: - ... +) -> Optional[str]: ... def parse_string_from_args( @@ -704,22 +677,19 @@ def parse_string_from_args( @overload -def parse_json_value_from_request(request: Request) -> JsonDict: - ... +def parse_json_value_from_request(request: Request) -> JsonDict: ... @overload def parse_json_value_from_request( request: Request, allow_empty_body: Literal[False] -) -> JsonDict: - ... +) -> JsonDict: ... @overload def parse_json_value_from_request( request: Request, allow_empty_body: bool = False -) -> Optional[JsonDict]: - ... +) -> Optional[JsonDict]: ... def parse_json_value_from_request( @@ -847,7 +817,6 @@ def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None: class RestServlet: - """A Synapse REST Servlet. An implementing class can either provide its own custom 'register' method, diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 548d255b69..4650b60962 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -744,8 +744,7 @@ def preserve_fn( @overload -def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: - ... +def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: ... def preserve_fn( @@ -774,15 +773,10 @@ def run_in_background( @overload def run_in_background( f: Callable[P, R], *args: P.args, **kwargs: P.kwargs -) -> "defer.Deferred[R]": - ... +) -> "defer.Deferred[R]": ... -def run_in_background( # type: ignore[misc] - # The `type: ignore[misc]` above suppresses - # "Overloaded function implementation does not accept all possible arguments of signature 1" - # "Overloaded function implementation does not accept all possible arguments of signature 2" - # which seems like a bug in mypy. +def run_in_background( f: Union[ Callable[P, R], Callable[P, Awaitable[R]], diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 78b9fffbfb..7a3c805cc5 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -388,15 +388,13 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: @overload def ensure_active_span( message: str, -) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]: - ... +) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]: ... @overload def ensure_active_span( message: str, ret: T -) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]: - ... +) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]: ... def ensure_active_span( diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 52859ed490..0e875132f6 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -1002,9 +1002,9 @@ class MediaRepository: ) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - thumbnails[ - (t_width, t_height, requirement.media_type) - ] = requirement.method + thumbnails[(t_width, t_height, requirement.media_type)] = ( + requirement.method + ) # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.items(): diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 6b4c64f7a5..bd25985686 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -42,14 +42,12 @@ class JemallocStats: @overload def _mallctl( self, name: str, read: Literal[True] = True, write: Optional[int] = None - ) -> int: - ... + ) -> int: ... @overload def _mallctl( self, name: str, read: Literal[False], write: Optional[int] = None - ) -> None: - ... + ) -> None: ... def _mallctl( self, name: str, read: bool = True, write: Optional[int] = None diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py index 6ec56a7f14..17079ff781 100644 --- a/synapse/module_api/callbacks/spamchecker_callbacks.py +++ b/synapse/module_api/callbacks/spamchecker_callbacks.py @@ -455,7 +455,7 @@ class SpamCheckerModuleApiCallbacks: # mypy complains that we can't reach this code because of the # return type in CHECK_EVENT_FOR_SPAM_CALLBACK, but we don't know # for sure that the module actually returns it. - logger.warning( + logger.warning( # type: ignore[unreachable] "Module returned invalid value, rejecting message as spam" ) res = "This message has been rejected as probable spam" diff --git a/synapse/notifier.py b/synapse/notifier.py index 62d954298c..e87333a80a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -469,8 +469,7 @@ class Notifier: new_token: RoomStreamToken, users: Optional[Collection[Union[str, UserID]]] = None, rooms: Optional[StrCollection] = None, - ) -> None: - ... + ) -> None: ... @overload def on_new_event( @@ -479,8 +478,7 @@ class Notifier: new_token: MultiWriterStreamToken, users: Optional[Collection[Union[str, UserID]]] = None, rooms: Optional[StrCollection] = None, - ) -> None: - ... + ) -> None: ... @overload def on_new_event( @@ -497,8 +495,7 @@ class Notifier: new_token: int, users: Optional[Collection[Union[str, UserID]]] = None, rooms: Optional[StrCollection] = None, - ) -> None: - ... + ) -> None: ... def on_new_event( self, diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b4bd88f308..f1ffc8115f 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -377,12 +377,14 @@ class Mailer: # # Note that many email clients will not render the unsubscribe link # unless DKIM, etc. is properly setup. - additional_headers={ - "List-Unsubscribe-Post": "List-Unsubscribe=One-Click", - "List-Unsubscribe": f"<{unsubscribe_link}>", - } - if unsubscribe_link - else None, + additional_headers=( + { + "List-Unsubscribe-Post": "List-Unsubscribe=One-Click", + "List-Unsubscribe": f"<{unsubscribe_link}>", + } + if unsubscribe_link + else None + ), ) async def _get_room_vars( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index a82ad49e01..9aa8d90bfe 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -259,9 +259,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): url_args.append(txn_id) if cls.METHOD == "POST": - request_func: Callable[ - ..., Awaitable[Any] - ] = client.post_json_get_json + request_func: Callable[..., Awaitable[Any]] = ( + client.post_json_get_json + ) elif cls.METHOD == "PUT": request_func = client.put_json elif cls.METHOD == "GET": diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index ce47d8035c..a95771b5f6 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -70,9 +70,9 @@ class ExternalCache: def __init__(self, hs: "HomeServer"): if hs.config.redis.redis_enabled: - self._redis_connection: Optional[ - "ConnectionHandler" - ] = hs.get_outbound_redis_connection() + self._redis_connection: Optional["ConnectionHandler"] = ( + hs.get_outbound_redis_connection() + ) else: self._redis_connection = None diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 07e0fb71f2..6da1d79168 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -237,10 +237,12 @@ class PurgeHistoryStatusRestServlet(RestServlet): raise NotFoundError("purge id '%s' not found" % purge_id) result: JsonDict = { - "status": purge_task.status - if purge_task.status == TaskStatus.COMPLETE - or purge_task.status == TaskStatus.FAILED - else "active", + "status": ( + purge_task.status + if purge_task.status == TaskStatus.COMPLETE + or purge_task.status == TaskStatus.FAILED + else "active" + ), } if purge_task.error: result["error"] = purge_task.error diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index a9645e4af7..4e34e46512 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1184,12 +1184,14 @@ class RateLimitRestServlet(RestServlet): # convert `null` to `0` for consistency # both values do the same in retelimit handler ret = { - "messages_per_second": 0 - if ratelimit.messages_per_second is None - else ratelimit.messages_per_second, - "burst_count": 0 - if ratelimit.burst_count is None - else ratelimit.burst_count, + "messages_per_second": ( + 0 + if ratelimit.messages_per_second is None + else ratelimit.messages_per_second + ), + "burst_count": ( + 0 if ratelimit.burst_count is None else ratelimit.burst_count + ), } else: ret = {} diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 0cdc4cc4f7..12ffca984f 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -112,9 +112,9 @@ class AccountDataServlet(RestServlet): self._hs.config.experimental.msc4010_push_rules_account_data and account_data_type == AccountDataTypes.PUSH_RULES ): - account_data: Optional[ - JsonMapping - ] = await self._push_rules_handler.push_rules_for_user(requester.user) + account_data: Optional[JsonMapping] = ( + await self._push_rules_handler.push_rules_for_user(requester.user) + ) else: account_data = await self.store.get_global_account_data_by_type_for_user( user_id, account_data_type diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 3af2b7dfd9..2b103ca6a8 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -313,12 +313,12 @@ class SyncRestServlet(RestServlet): # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md # states that this field should always be included, as long as the server supports the feature. - response[ - "org.matrix.msc2732.device_unused_fallback_key_types" - ] = sync_result.device_unused_fallback_key_types - response[ - "device_unused_fallback_key_types" - ] = sync_result.device_unused_fallback_key_types + response["org.matrix.msc2732.device_unused_fallback_key_types"] = ( + sync_result.device_unused_fallback_key_types + ) + response["device_unused_fallback_key_types"] = ( + sync_result.device_unused_fallback_key_types + ) if joined: response["rooms"][Membership.JOIN] = joined @@ -543,9 +543,9 @@ class SyncRestServlet(RestServlet): if room.unread_thread_notifications: result["unread_thread_notifications"] = room.unread_thread_notifications if self._msc3773_enabled: - result[ - "org.matrix.msc3773.unread_thread_notifications" - ] = room.unread_thread_notifications + result["org.matrix.msc3773.unread_thread_notifications"] = ( + room.unread_thread_notifications + ) result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 6afe4a7bcc..dc7325fc57 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -191,10 +191,10 @@ class RemoteKey(RestServlet): server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): if key_ids: - results: Mapping[ - str, Optional[FetchKeyResultForRemote] - ] = await self.store.get_server_keys_json_for_remote( - server_name, key_ids + results: Mapping[str, Optional[FetchKeyResultForRemote]] = ( + await self.store.get_server_keys_json_for_remote( + server_name, key_ids + ) ) else: results = await self.store.get_all_server_keys_json_for_remote( diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 015e49ab81..72b291889b 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -603,15 +603,15 @@ class StateResolutionHandler: self.resolve_linearizer = Linearizer(name="state_resolve_lock") # dict of set of event_ids -> _StateCacheEntry. - self._state_cache: ExpiringCache[ - FrozenSet[int], _StateCacheEntry - ] = ExpiringCache( - cache_name="state_cache", - clock=self.clock, - max_len=100000, - expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, - iterable=True, - reset_expiry_on_get=True, + self._state_cache: ExpiringCache[FrozenSet[int], _StateCacheEntry] = ( + ExpiringCache( + cache_name="state_cache", + clock=self.clock, + max_len=100000, + expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, + iterable=True, + reset_expiry_on_get=True, + ) ) # diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 8de16db1d0..da926ad146 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -52,8 +52,7 @@ class Clock(Protocol): # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests. # We only ever sleep(0) though, so that other async functions can make forward # progress without waiting for stateres to complete. - def sleep(self, duration_ms: float) -> Awaitable[None]: - ... + def sleep(self, duration_ms: float) -> Awaitable[None]: ... class StateResolutionStore(Protocol): @@ -61,13 +60,11 @@ class StateResolutionStore(Protocol): # TestStateResolutionStore in tests. def get_events( self, event_ids: StrCollection, allow_rejected: bool = False - ) -> Awaitable[Dict[str, EventBase]]: - ... + ) -> Awaitable[Dict[str, EventBase]]: ... def get_auth_chain_difference( self, room_id: str, state_sets: List[Set[str]] - ) -> Awaitable[Set[str]]: - ... + ) -> Awaitable[Set[str]]: ... # We want to await to the reactor occasionally during state res when dealing @@ -742,8 +739,7 @@ async def _get_event( event_map: Dict[str, EventBase], state_res_store: StateResolutionStore, allow_none: Literal[False] = False, -) -> EventBase: - ... +) -> EventBase: ... @overload @@ -753,8 +749,7 @@ async def _get_event( event_map: Dict[str, EventBase], state_res_store: StateResolutionStore, allow_none: Literal[True], -) -> Optional[EventBase]: - ... +) -> Optional[EventBase]: ... async def _get_event( diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 9df4edee38..f473294070 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -836,9 +836,9 @@ class BackgroundUpdater: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner: Optional[ - Callable[[LoggingDatabaseConnection], None] - ] = create_index_psql + runner: Optional[Callable[[LoggingDatabaseConnection], None]] = ( + create_index_psql + ) elif psql_only: runner = None else: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 69d5999c0a..84699a2ee1 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -773,9 +773,9 @@ class EventsPersistenceStorageController: ) # Remove any events which are prev_events of any existing events. - existing_prevs: Collection[ - str - ] = await self.persist_events_store._get_events_which_are_prevs(result) + existing_prevs: Collection[str] = ( + await self.persist_events_store._get_events_which_are_prevs(result) + ) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 8dc9080842..d9c85e411e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -111,8 +111,7 @@ class _PoolConnection(Connection): A Connection from twisted.enterprise.adbapi.Connection. """ - def reconnect(self) -> None: - ... + def reconnect(self) -> None: ... def make_pool( @@ -914,9 +913,9 @@ class DatabasePool: try: with opentracing.start_active_span(f"db.{desc}"): - result = await self.runWithConnection( + result: R = await self.runWithConnection( # mypy seems to have an issue with this, maybe a bug? - self.new_transaction, # type: ignore[arg-type] + self.new_transaction, desc, after_callbacks, async_after_callbacks, @@ -935,7 +934,7 @@ class DatabasePool: await async_callback(*async_args, **async_kwargs) for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) - return cast(R, result) + return result except Exception: for exception_callback, after_args, after_kwargs in exception_callbacks: exception_callback(*after_args, **after_kwargs) @@ -1603,8 +1602,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", - ) -> Tuple[Any, ...]: - ... + ) -> Tuple[Any, ...]: ... @overload async def simple_select_one( @@ -1614,8 +1612,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", - ) -> Optional[Tuple[Any, ...]]: - ... + ) -> Optional[Tuple[Any, ...]]: ... async def simple_select_one( self, @@ -1654,8 +1651,7 @@ class DatabasePool: retcol: str, allow_none: Literal[False] = False, desc: str = "simple_select_one_onecol", - ) -> Any: - ... + ) -> Any: ... @overload async def simple_select_one_onecol( @@ -1665,8 +1661,7 @@ class DatabasePool: retcol: str, allow_none: Literal[True] = True, desc: str = "simple_select_one_onecol", - ) -> Optional[Any]: - ... + ) -> Optional[Any]: ... async def simple_select_one_onecol( self, @@ -1706,8 +1701,7 @@ class DatabasePool: keyvalues: Dict[str, Any], retcol: str, allow_none: Literal[False] = False, - ) -> Any: - ... + ) -> Any: ... @overload @classmethod @@ -1718,8 +1712,7 @@ class DatabasePool: keyvalues: Dict[str, Any], retcol: str, allow_none: Literal[True] = True, - ) -> Optional[Any]: - ... + ) -> Optional[Any]: ... @classmethod def simple_select_one_onecol_txn( @@ -2501,8 +2494,7 @@ def make_tuple_in_list_sql_clause( database_engine: BaseDatabaseEngine, columns: Tuple[str, str], iterable: Collection[Tuple[Any, Any]], -) -> Tuple[str, list]: - ... +) -> Tuple[str, list]: ... def make_tuple_in_list_sql_clause( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3e011f3340..8dbcb3f5a0 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1701,9 +1701,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache: LruCache[ - Tuple[str, str], Literal[True] - ] = LruCache(cache_name="device_id_exists", max_size=10000) + self.device_id_exists_cache: LruCache[Tuple[str, str], Literal[True]] = ( + LruCache(cache_name="device_id_exists", max_size=10000) + ) async def store_device( self, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c96371a0d3..b219ea70ee 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -256,8 +256,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker self, query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: Literal[False] = False, - ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: - ... + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: ... @overload async def get_e2e_device_keys_and_signatures( @@ -265,8 +264,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: Literal[False] = False, - ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: - ... + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: ... @overload async def get_e2e_device_keys_and_signatures( @@ -274,8 +272,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: Literal[True], include_deleted_devices: Literal[True], - ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: - ... + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ... @trace @cancellable diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d5942a10b2..a6fda3f43c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1292,9 +1292,9 @@ class PersistEventsStore: Returns: filtered list """ - new_events_and_contexts: OrderedDict[ - str, Tuple[EventBase, EventContext] - ] = OrderedDict() + new_events_and_contexts: OrderedDict[str, Tuple[EventBase, EventContext]] = ( + OrderedDict() + ) for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) if prev_event_context: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 9c3775bb7c..81fccfbccb 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -263,13 +263,13 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: AsyncLruCache[ - Tuple[str], EventCacheEntry - ] = AsyncLruCache( - cache_name="*getEvent*", - max_size=hs.config.caches.event_cache_size, - # `extra_index_cb` Returns a tuple as that is the key type - extra_index_cb=lambda _, v: (v.event.room_id,), + self._get_event_cache: AsyncLruCache[Tuple[str], EventCacheEntry] = ( + AsyncLruCache( + cache_name="*getEvent*", + max_size=hs.config.caches.event_cache_size, + # `extra_index_cb` Returns a tuple as that is the key type + extra_index_cb=lambda _, v: (v.event.room_id,), + ) ) # Map from event ID to a deferred that will result in a map from event @@ -459,8 +459,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = ..., allow_none: Literal[False] = ..., check_room_id: Optional[str] = ..., - ) -> EventBase: - ... + ) -> EventBase: ... @overload async def get_event( @@ -471,8 +470,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = ..., allow_none: Literal[True] = ..., check_room_id: Optional[str] = ..., - ) -> Optional[EventBase]: - ... + ) -> Optional[EventBase]: ... @cancellable async def get_event( @@ -800,9 +798,9 @@ class EventsWorkerStore(SQLBaseStore): # to all the events we pulled from the DB (this will result in this # function returning more events than requested, but that can happen # already due to `_get_events_from_db`). - fetching_deferred: ObservableDeferred[ - Dict[str, EventCacheEntry] - ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) + fetching_deferred: ObservableDeferred[Dict[str, EventCacheEntry]] = ( + ObservableDeferred(defer.Deferred(), consumeErrors=True) + ) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -1871,14 +1869,14 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[ - Tuple[int, Tuple[str, str, str, str, str, str]] - ] = [] + new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = ( + [] + ) row: Tuple[int, str, str, str, str, str, str] # Type safety: iterating over `txn` yields `Tuple`, i.e. # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] + for row in txn: new_event_updates.append((row[0], row[1:])) limited = False @@ -1905,7 +1903,7 @@ class EventsWorkerStore(SQLBaseStore): # Type safety: iterating over `txn` yields `Tuple`, i.e. # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a # variadic tuple to a fixed length tuple and flags it up as an error. - for row in txn: # type: ignore[assignment] + for row in txn: new_event_updates.append((row[0], row[1:])) if len(new_event_updates) >= limit: diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 0794cc6d25..8277ad8c33 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -79,9 +79,9 @@ class LockStore(SQLBaseStore): # A map from `(lock_name, lock_key)` to lock that we think we # currently hold. - self._live_lock_tokens: WeakValueDictionary[ - Tuple[str, str], Lock - ] = WeakValueDictionary() + self._live_lock_tokens: WeakValueDictionary[Tuple[str, str], Lock] = ( + WeakValueDictionary() + ) # A map from `(lock_name, lock_key, token)` to read/write lock that we # think we currently hold. For a given lock_name/lock_key, there can be diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b5ed1bf9c8..6128332af8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -158,9 +158,9 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): ) if hs.config.media.can_load_media_repo: - self.unused_expiration_time: Optional[ - int - ] = hs.config.media.unused_expiration_time + self.unused_expiration_time: Optional[int] = ( + hs.config.media.unused_expiration_time + ) else: self.unused_expiration_time = None diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 8a426d2875..d513c42530 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -394,9 +394,9 @@ class ReceiptsWorkerStore(SQLBaseStore): content: JsonDict = {} for receipt_type, user_id, event_id, data in rows: - content.setdefault(event_id, {}).setdefault(receipt_type, {})[ - user_id - ] = db_to_json(data) + content.setdefault(event_id, {}).setdefault(receipt_type, {})[user_id] = ( + db_to_json(data) + ) return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] @@ -483,9 +483,9 @@ class ReceiptsWorkerStore(SQLBaseStore): if user_id in receipt_type_dict: # existing receipt # is the existing receipt threaded and we are currently processing an unthreaded one? if "thread_id" in receipt_type_dict[user_id] and not thread_id: - receipt_type_dict[ - user_id - ] = receipt_data # replace with unthreaded one + receipt_type_dict[user_id] = ( + receipt_data # replace with unthreaded one + ) else: # receipt does not exist, just set it receipt_type_dict[user_id] = receipt_data if thread_id: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 3220d515d9..b2a67aff89 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -768,12 +768,10 @@ class StateMapWrapper(Dict[StateKey, str]): return super().__getitem__(key) @overload - def get(self, key: Tuple[str, str]) -> Optional[str]: - ... + def get(self, key: Tuple[str, str]) -> Optional[str]: ... @overload - def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: - ... + def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: ... def get( self, key: StateKey, default: Union[str, _T, None] = None diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 19041cc35b..7ab6003f61 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -988,8 +988,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn: LoggingTransaction, event_id: str, allow_none: Literal[False] = False, - ) -> int: - ... + ) -> int: ... @overload def get_stream_id_for_event_txn( @@ -997,8 +996,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn: LoggingTransaction, event_id: str, allow_none: bool = False, - ) -> Optional[int]: - ... + ) -> Optional[int]: ... def get_stream_id_for_event_txn( self, @@ -1476,12 +1474,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token - if direction == Direction.BACKWARDS - else from_token, - upper_token=from_token - if direction == Direction.BACKWARDS - else to_token, + lower_token=( + to_token if direction == Direction.BACKWARDS else from_token + ), + upper_token=( + from_token if direction == Direction.BACKWARDS else to_token + ), instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py index 7b95616432..4956870b1a 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py @@ -136,12 +136,12 @@ class TaskSchedulerWorkerStore(SQLBaseStore): "status": task.status, "timestamp": task.timestamp, "resource_id": task.resource_id, - "params": None - if task.params is None - else json_encoder.encode(task.params), - "result": None - if task.result is None - else json_encoder.encode(task.result), + "params": ( + None if task.params is None else json_encoder.encode(task.params) + ), + "result": ( + None if task.result is None else json_encoder.encode(task.result) + ), "error": task.error, }, desc="insert_scheduled_task", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a1c4b8c6c3..0513e7dc06 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -745,9 +745,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): p.user_id, get_localpart_from_id(p.user_id), get_domain_from_id(p.user_id), - _filter_text_for_index(p.display_name) - if p.display_name - else None, + ( + _filter_text_for_index(p.display_name) + if p.display_name + else None + ), ) for p in profiles ], diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e64495ba8d..d4ac74c1ee 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -120,11 +120,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # TODO: this hasn't been tuned yet 50000, ) - self._state_group_members_cache: DictionaryCache[ - int, StateKey, str - ] = DictionaryCache( - "*stateGroupMembersCache*", - 500000, + self._state_group_members_cache: DictionaryCache[int, StateKey, str] = ( + DictionaryCache( + "*stateGroupMembersCache*", + 500000, + ) ) def get_max_state_group_txn(txn: Cursor) -> int: diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 8c29236b59..ad222e7e2d 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -48,8 +48,7 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM @property @abc.abstractmethod - def single_threaded(self) -> bool: - ... + def single_threaded(self) -> bool: ... @property @abc.abstractmethod @@ -68,8 +67,7 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM @abc.abstractmethod def check_database( self, db_conn: ConnectionType, allow_outdated_version: bool = False - ) -> None: - ... + ) -> None: ... @abc.abstractmethod def check_new_database(self, txn: CursorType) -> None: @@ -79,27 +77,22 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM ... @abc.abstractmethod - def convert_param_style(self, sql: str) -> str: - ... + def convert_param_style(self, sql: str) -> str: ... # This method would ideally take a plain ConnectionType, but it seems that # the Sqlite engine expects to use LoggingDatabaseConnection.cursor # instead of sqlite3.Connection.cursor: only the former takes a txn_name. @abc.abstractmethod - def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: - ... + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: ... @abc.abstractmethod - def is_deadlock(self, error: Exception) -> bool: - ... + def is_deadlock(self, error: Exception) -> bool: ... @abc.abstractmethod - def is_connection_closed(self, conn: ConnectionType) -> bool: - ... + def is_connection_closed(self, conn: ConnectionType) -> bool: ... @abc.abstractmethod - def lock_table(self, txn: Cursor, table: str) -> None: - ... + def lock_table(self, txn: Cursor, table: str) -> None: ... @property @abc.abstractmethod diff --git a/synapse/storage/types.py b/synapse/storage/types.py index b4e0a8f576..74f60cc590 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -42,20 +42,17 @@ SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]] class Cursor(Protocol): - def execute(self, sql: str, parameters: SQLQueryParameters = ...) -> Any: - ... + def execute(self, sql: str, parameters: SQLQueryParameters = ...) -> Any: ... - def executemany(self, sql: str, parameters: Sequence[SQLQueryParameters]) -> Any: - ... + def executemany( + self, sql: str, parameters: Sequence[SQLQueryParameters] + ) -> Any: ... - def fetchone(self) -> Optional[Tuple]: - ... + def fetchone(self) -> Optional[Tuple]: ... - def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: - ... + def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: ... - def fetchall(self) -> List[Tuple]: - ... + def fetchall(self) -> List[Tuple]: ... @property def description( @@ -70,36 +67,28 @@ class Cursor(Protocol): def rowcount(self) -> int: return 0 - def __iter__(self) -> Iterator[Tuple]: - ... + def __iter__(self) -> Iterator[Tuple]: ... - def close(self) -> None: - ... + def close(self) -> None: ... class Connection(Protocol): - def cursor(self) -> Cursor: - ... + def cursor(self) -> Cursor: ... - def close(self) -> None: - ... + def close(self) -> None: ... - def commit(self) -> None: - ... + def commit(self) -> None: ... - def rollback(self) -> None: - ... + def rollback(self) -> None: ... - def __enter__(self) -> "Connection": - ... + def __enter__(self) -> "Connection": ... def __exit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], - ) -> Optional[bool]: - ... + ) -> Optional[bool]: ... class DBAPI2Module(Protocol): @@ -129,24 +118,20 @@ class DBAPI2Module(Protocol): # explain why this is necessary for safety. TL;DR: we shouldn't be able to write # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 . @property - def Warning(self) -> Type[Exception]: - ... + def Warning(self) -> Type[Exception]: ... @property - def Error(self) -> Type[Exception]: - ... + def Error(self) -> Type[Exception]: ... # Errors are divided into `InterfaceError`s (something went wrong in the database # driver) and `DatabaseError`s (something went wrong in the database). These are # both subclasses of `Error`, but we can't currently express this in type # annotations due to https://github.com/python/mypy/issues/8397 @property - def InterfaceError(self) -> Type[Exception]: - ... + def InterfaceError(self) -> Type[Exception]: ... @property - def DatabaseError(self) -> Type[Exception]: - ... + def DatabaseError(self) -> Type[Exception]: ... # Everything below is a subclass of `DatabaseError`. @@ -155,8 +140,7 @@ class DBAPI2Module(Protocol): # - An invalid date time was provided. # - A string contained a null code point. @property - def DataError(self) -> Type[Exception]: - ... + def DataError(self) -> Type[Exception]: ... # Roughly: something went wrong in the database, but it's not within the application # programmer's control. Examples: @@ -167,21 +151,18 @@ class DBAPI2Module(Protocol): # - The database ran out of resources, such as storage, memory, connections, etc. # - The database encountered an error from the operating system. @property - def OperationalError(self) -> Type[Exception]: - ... + def OperationalError(self) -> Type[Exception]: ... # Roughly: we've given the database data which breaks a rule we asked it to enforce. # Examples: # - Stop, criminal scum! You violated the foreign key constraint # - Also check constraints, non-null constraints, etc. @property - def IntegrityError(self) -> Type[Exception]: - ... + def IntegrityError(self) -> Type[Exception]: ... # Roughly: something went wrong within the database server itself. @property - def InternalError(self) -> Type[Exception]: - ... + def InternalError(self) -> Type[Exception]: ... # Roughly: the application did something silly that needs to be fixed. Examples: # - We don't have permissions to do something. @@ -189,13 +170,11 @@ class DBAPI2Module(Protocol): # - We tried to use a reserved name. # - We referred to a column that doesn't exist. @property - def ProgrammingError(self) -> Type[Exception]: - ... + def ProgrammingError(self) -> Type[Exception]: ... # Roughly: we've tried to do something that this database doesn't support. @property - def NotSupportedError(self) -> Type[Exception]: - ... + def NotSupportedError(self) -> Type[Exception]: ... # We originally wrote # def connect(self, *args, **kwargs) -> Connection: ... @@ -204,8 +183,7 @@ class DBAPI2Module(Protocol): # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use # the following slightly unusual workaround. @property - def connect(self) -> Callable[..., Connection]: - ... + def connect(self) -> Callable[..., Connection]: ... __all__ = ["Cursor", "Connection", "DBAPI2Module"] diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 7466488157..dd7401ac8e 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -57,12 +57,13 @@ class _EventSourcesInner: class EventSources: def __init__(self, hs: "HomeServer"): self.sources = _EventSourcesInner( - # mypy previously warned that attribute.type is `Optional`, but we know it's + # attribute.type is `Optional`, but we know it's # never `None` here since all the attributes of `_EventSourcesInner` are # annotated. - # As of the stubs in attrs 22.1.0, `attr.fields()` now returns Any, - # so the call to `attribute.type` is not checked. - *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) + *( + attribute.type(hs) # type: ignore[misc] + for attribute in attr.fields(_EventSourcesInner) + ) ) self.store = hs.get_datastores().main self._instance_name = hs.get_instance_name() diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 423ede5969..69837617f5 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -56,7 +56,7 @@ class EventInternalMetadata: (Added in synapse 0.99.0, so may be unreliable for events received before that) """ - ... + def get_send_on_behalf_of(self) -> Optional[str]: """Whether this server should send the event on behalf of another server. This is used by the federation "send_join" API to forward the initial join @@ -64,7 +64,7 @@ class EventInternalMetadata: returns a str with the name of the server this event is sent on behalf of. """ - ... + def need_to_check_redaction(self) -> bool: """Whether the redaction event needs to be rechecked when fetching from the database. @@ -75,7 +75,7 @@ class EventInternalMetadata: If the sender of the redaction event is allowed to redact any event due to auth rules, then this will always return false. """ - ... + def is_soft_failed(self) -> bool: """Whether the event has been soft failed. @@ -85,7 +85,7 @@ class EventInternalMetadata: 2. They should not be added to the forward extremities (and therefore not to current state). """ - ... + def should_proactively_send(self) -> bool: """Whether the event, if ours, should be sent to other clients and servers. @@ -93,14 +93,13 @@ class EventInternalMetadata: This is used for sending dummy events internally. Servers and clients can still explicitly fetch the event. """ - ... + def is_redacted(self) -> bool: """Whether the event has been redacted. This is used for efficiently checking whether an event has been marked as redacted without needing to make another database call. """ - ... + def is_notifiable(self) -> bool: """Whether this event can trigger a push notification""" - ... diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index d3ee718375..a88982a04c 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -976,12 +976,12 @@ class StreamToken: return attr.evolve(self, **{key.value: new_value}) @overload - def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: - ... + def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: ... @overload - def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken: - ... + def get_field( + self, key: Literal[StreamKeyType.RECEIPT] + ) -> MultiWriterStreamToken: ... @overload def get_field( @@ -995,14 +995,12 @@ class StreamToken: StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, ], - ) -> int: - ... + ) -> int: ... @overload def get_field( self, key: StreamKeyType - ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: - ... + ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: ... def get_field( self, key: StreamKeyType diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 914d4fd747..70139beef2 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -284,15 +284,7 @@ async def yieldable_gather_results( try: return await make_deferred_yieldable( defer.gatherResults( - # type-ignore: mypy reports two errors: - # error: Argument 1 to "run_in_background" has incompatible type - # "Callable[[T, **P], Awaitable[R]]"; expected - # "Callable[[T, **P], Awaitable[R]]" [arg-type] - # error: Argument 2 to "run_in_background" has incompatible type - # "T"; expected "[T, **P.args]" [arg-type] - # The former looks like a mypy bug, and the latter looks like a - # false positive. - [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] + [run_in_background(func, item, *args, **kwargs) for item in iter], consumeErrors=True, ) ) @@ -338,7 +330,7 @@ async def yieldable_gather_results_delaying_cancellation( return await make_deferred_yieldable( delay_cancellation( defer.gatherResults( - [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] + [run_in_background(func, item, *args, **kwargs) for item in iter], consumeErrors=True, ) ) @@ -357,24 +349,21 @@ T4 = TypeVar("T4") @overload def gather_results( deferredList: Tuple[()], consumeErrors: bool = ... -) -> "defer.Deferred[Tuple[()]]": - ... +) -> "defer.Deferred[Tuple[()]]": ... @overload def gather_results( deferredList: Tuple["defer.Deferred[T1]"], consumeErrors: bool = ..., -) -> "defer.Deferred[Tuple[T1]]": - ... +) -> "defer.Deferred[Tuple[T1]]": ... @overload def gather_results( deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"], consumeErrors: bool = ..., -) -> "defer.Deferred[Tuple[T1, T2]]": - ... +) -> "defer.Deferred[Tuple[T1, T2]]": ... @overload @@ -383,8 +372,7 @@ def gather_results( "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]" ], consumeErrors: bool = ..., -) -> "defer.Deferred[Tuple[T1, T2, T3]]": - ... +) -> "defer.Deferred[Tuple[T1, T2, T3]]": ... @overload @@ -396,8 +384,7 @@ def gather_results( "defer.Deferred[T4]", ], consumeErrors: bool = ..., -) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]": - ... +) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]": ... def gather_results( # type: ignore[misc] @@ -782,18 +769,15 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": @overload -def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": - ... +def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": ... @overload -def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": - ... +def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": ... @overload -def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: - ... +def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: ... def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 4245b7289c..1e6696332f 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -229,7 +229,7 @@ class DictionaryCache(Generic[KT, DKT, DV]): for dict_key in missing: # We explicitly add each dict key to the cache, so that cache hit # rates and LRU times for each key can be tracked separately. - value = entry.get(dict_key, _Sentinel.sentinel) # type: ignore[arg-type] + value = entry.get(dict_key, _Sentinel.sentinel) self.cache[(key, dict_key)] = _PerKeyValue(value) if value is not _Sentinel.sentinel: diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index a52ba59a34..8017c031ee 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -142,7 +142,7 @@ class ExpiringCache(Generic[KT, VT]): return default if self.iterable: - self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) # type: ignore[arg-type] + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) else: self.metrics.inc_evictions(EvictionReason.invalidation) @@ -152,12 +152,10 @@ class ExpiringCache(Generic[KT, VT]): return key in self._cache @overload - def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: - ... + def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: ... @overload - def get(self, key: KT, default: T) -> Union[VT, T]: - ... + def get(self, key: KT, default: T) -> Union[VT, T]: ... def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]: try: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index a1b4f5b6a7..481a1a621e 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -580,8 +580,7 @@ class LruCache(Generic[KT, VT]): callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., update_last_access: bool = ..., - ) -> Optional[VT]: - ... + ) -> Optional[VT]: ... @overload def cache_get( @@ -590,8 +589,7 @@ class LruCache(Generic[KT, VT]): callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., update_last_access: bool = ..., - ) -> Union[T, VT]: - ... + ) -> Union[T, VT]: ... @synchronized def cache_get( @@ -634,16 +632,14 @@ class LruCache(Generic[KT, VT]): key: tuple, default: Literal[None] = None, update_metrics: bool = True, - ) -> Union[None, Iterable[Tuple[KT, VT]]]: - ... + ) -> Union[None, Iterable[Tuple[KT, VT]]]: ... @overload def cache_get_multi( key: tuple, default: T, update_metrics: bool = True, - ) -> Union[T, Iterable[Tuple[KT, VT]]]: - ... + ) -> Union[T, Iterable[Tuple[KT, VT]]]: ... @synchronized def cache_get_multi( @@ -728,12 +724,10 @@ class LruCache(Generic[KT, VT]): return value @overload - def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: - ... + def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: ... @overload - def cache_pop(key: KT, default: T) -> Union[T, VT]: - ... + def cache_pop(key: KT, default: T) -> Union[T, VT]: ... @synchronized def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index 082ad8cedb..b73f690b88 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -50,8 +50,7 @@ class _SelfSlice(Sized, Protocol): returned. """ - def __getitem__(self: S, i: slice) -> S: - ... + def __getitem__(self: S, i: slice) -> S: ... def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index dc9bddb00d..8ead72bb7a 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -177,9 +177,9 @@ class FederationRateLimiter: clock=clock, config=config, metrics_name=metrics_name ) - self.ratelimiters: DefaultDict[ - str, "_PerHostRatelimiter" - ] = collections.defaultdict(new_limiter) + self.ratelimiters: DefaultDict[str, "_PerHostRatelimiter"] = ( + collections.defaultdict(new_limiter) + ) with _rate_limiter_instances_lock: _rate_limiter_instances.add(self) diff --git a/synapse/visibility.py b/synapse/visibility.py index e58f649aaf..d1d478129f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -129,9 +129,9 @@ async def filter_events_for_client( retention_policies: Dict[str, RetentionPolicy] = {} for room_id in room_ids: - retention_policies[ - room_id - ] = await storage.main.get_retention_policy_for_room(room_id) + retention_policies[room_id] = ( + await storage.main.get_retention_policy_for_room(room_id) + ) def allowed(event: EventBase) -> Optional[EventBase]: return _check_client_allowed_to_see_event( |