From 84169a82dcf7dfb6eb7d307ea7f5e33cb57f6e3f Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:53:02 +0100 Subject: Avoid blocking lazy-loading `/sync`s during partial joins (#13477) Use a state filter or accept partial state in a few places where we request state, to avoid blocking. To make lazy-loading `/sync`s work, we need to provide the memberships of event senders, which are not guaranteed to be in the room state. Instead we dig through auth events for memberships to present to clients. The auth events of an event are guaranteed to contain a passable membership event, otherwise the event would have been rejected. Note that this only covers the common code paths encountered during testing. There has been no exhaustive checking of all sync code paths. Fixes #13146. Signed-off-by: Sean Quah --- synapse/handlers/sync.py | 253 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 223 insertions(+), 30 deletions(-) (limited to 'synapse/handlers/sync.py') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3ca01391c9..b4d3f3958c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,9 +16,11 @@ import logging from typing import ( TYPE_CHECKING, Any, + Collection, Dict, FrozenSet, List, + Mapping, Optional, Sequence, Set, @@ -517,10 +519,17 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `recents`, so partial state is only a problem when a membership + # event turns up in `recents` but has not made it into the current + # state. current_state_ids_map = ( - await self._state_storage_controller.get_current_state_ids( - room_id - ) + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -589,7 +598,13 @@ class SyncHandler: if any(e.is_state() for e in loaded_recents): # FIXME(faster_joins): We use the partial state here as # we don't want to block `/sync` on finishing a lazy join. - # Is this the correct way of doing it? + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `loaded_recents`, so partial state is only a problem when a + # membership event turns up in `loaded_recents` but has not made it + # into the current state. current_state_ids_map = ( await self.store.get_partial_current_state_ids(room_id) ) @@ -637,7 +652,10 @@ class SyncHandler: ) async def get_state_after_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the room state after the given event @@ -645,9 +663,14 @@ class SyncHandler: Args: event_id: event of interest state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. """ state_ids = await self._state_storage_controller.get_state_ids_for_event( - event_id, state_filter=state_filter or StateFilter.all() + event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) # using get_metadata_for_events here (instead of get_event) sidesteps an issue @@ -670,6 +693,7 @@ class SyncHandler: room_id: str, stream_position: StreamToken, state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -677,6 +701,9 @@ class SyncHandler: room_id: room for which to get state stream_position: point at which to get state state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the last event in the room before `stream_position` and + `state_filter` is not satisfied by partial state. Defaults to `True`. """ # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time @@ -688,7 +715,9 @@ class SyncHandler: if last_event_id: state = await self.get_state_after_event( - last_event_id, state_filter=state_filter or StateFilter.all() + last_event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) else: @@ -891,7 +920,15 @@ class SyncHandler: with Measure(self.clock, "compute_state_delta"): # The memberships needed for events in the timeline. # Only calculated when `lazy_load_members` is on. - members_to_fetch = None + members_to_fetch: Optional[Set[str]] = None + + # A dictionary mapping user IDs to the first event in the timeline sent by + # them. Only calculated when `lazy_load_members` is on. + first_event_by_sender_map: Optional[Dict[str, EventBase]] = None + + # The contribution to the room state from state events in the timeline. + # Only contains the last event for any given state key. + timeline_state: StateMap[str] lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -902,10 +939,23 @@ class SyncHandler: # We only request state for the members needed to display the # timeline: - members_to_fetch = { - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - } + timeline_state = {} + + members_to_fetch = set() + first_event_by_sender_map = {} + for event in batch.events: + # Build the map from user IDs to the first timeline event they sent. + if event.sender not in first_event_by_sender_map: + first_event_by_sender_map[event.sender] = event + + # We need the event's sender, unless their membership was in a + # previous timeline event. + if (EventTypes.Member, event.sender) not in timeline_state: + members_to_fetch.add(event.sender) + # FIXME: we also care about invite targets etc. + + if event.is_state(): + timeline_state[(event.type, event.state_key)] = event.event_id if full_state: # always make sure we LL ourselves so we know we're in the room @@ -915,16 +965,21 @@ class SyncHandler: members_to_fetch.add(sync_config.user.to_string()) state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + + # We are happy to use partial state to compute the `/sync` response. + # Since partial state may not include the lazy-loaded memberships we + # require, we fix up the state response afterwards with memberships from + # auth events. + await_full_state = False else: - state_filter = StateFilter.all() + timeline_state = { + (event.type, event.state_key): event.event_id + for event in batch.events + if event.is_state() + } - # The contribution to the room state from state events in the timeline. - # Only contains the last event for any given state key. - timeline_state = { - (event.type, event.state_key): event.event_id - for event in batch.events - if event.is_state() - } + state_filter = StateFilter.all() + await_full_state = True # Now calculate the state to return in the sync response for the room. # This is more or less the change in state between the end of the previous @@ -936,19 +991,26 @@ class SyncHandler: if batch: state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: state_at_timeline_end = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_at_timeline_start = state_at_timeline_end @@ -964,14 +1026,19 @@ class SyncHandler: if batch: state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_start = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) # for now, we disable LL for gappy syncs - see @@ -993,20 +1060,28 @@ class SyncHandler: # is indeed the case. assert since_token is not None state_at_previous_sync = await self.get_state_at( - room_id, stream_position=since_token, state_filter=state_filter + room_id, + stream_position=since_token, + state_filter=state_filter, + await_full_state=await_full_state, ) if batch: state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_end = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_ids = _calculate_state( @@ -1036,8 +1111,23 @@ class SyncHandler: (EventTypes.Member, member) for member in members_to_fetch ), + await_full_state=False, ) + # If we only have partial state for the room, `state_ids` may be missing the + # memberships we wanted. We attempt to find some by digging through the auth + # events of timeline events. + if lazy_load_members and await self.store.is_partial_state_room(room_id): + assert members_to_fetch is not None + assert first_event_by_sender_map is not None + + additional_state_ids = ( + await self._find_missing_partial_state_memberships( + room_id, members_to_fetch, first_event_by_sender_map, state_ids + ) + ) + state_ids = {**state_ids, **additional_state_ids} + # At this point, if `lazy_load_members` is enabled, `state_ids` includes # the memberships of all event senders in the timeline. This is because we # may not have sent the memberships in a previous sync. @@ -1086,6 +1176,99 @@ class SyncHandler: if e.type != EventTypes.Aliases # until MSC2261 or alternative solution } + async def _find_missing_partial_state_memberships( + self, + room_id: str, + members_to_fetch: Collection[str], + events_with_membership_auth: Mapping[str, EventBase], + found_state_ids: StateMap[str], + ) -> StateMap[str]: + """Finds missing memberships from a set of auth events and returns them as a + state map. + + Args: + room_id: The partial state room to find the remaining memberships for. + members_to_fetch: The memberships to find. + events_with_membership_auth: A mapping from user IDs to events whose auth + events are known to contain their membership. + found_state_ids: A dict from (type, state_key) -> state_event_id, containing + memberships that have been previously found. Entries in + `members_to_fetch` that have a membership in `found_state_ids` are + ignored. + + Returns: + A dict from ("m.room.member", state_key) -> state_event_id, containing the + memberships missing from `found_state_ids`. + + Raises: + KeyError: if `events_with_membership_auth` does not have an entry for a + missing membership. Memberships in `found_state_ids` do not need an + entry in `events_with_membership_auth`. + """ + additional_state_ids: MutableStateMap[str] = {} + + # Tracks the missing members for logging purposes. + missing_members = set() + + # Identify memberships missing from `found_state_ids` and pick out the auth + # events in which to look for them. + auth_event_ids: Set[str] = set() + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + missing_members.add(member) + event_with_membership_auth = events_with_membership_auth[member] + auth_event_ids.update(event_with_membership_auth.auth_event_ids()) + + auth_events = await self.store.get_events(auth_event_ids) + + # Run through the missing memberships once more, picking out the memberships + # from the pile of auth events we have just fetched. + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + event_with_membership_auth = events_with_membership_auth[member] + + # Dig through the auth events to find the desired membership. + for auth_event_id in event_with_membership_auth.auth_event_ids(): + # We only store events once we have all their auth events, + # so the auth event must be in the pile we have just + # fetched. + auth_event = auth_events[auth_event_id] + + if ( + auth_event.type == EventTypes.Member + and auth_event.state_key == member + ): + missing_members.remove(member) + additional_state_ids[ + (EventTypes.Member, member) + ] = auth_event.event_id + break + + if missing_members: + # There really shouldn't be any missing memberships now. Either: + # * we couldn't find an auth event, which shouldn't happen because we do + # not persist events with persisting their auth events first, or + # * the set of auth events did not contain a membership we wanted, which + # means our caller didn't compute the events in `members_to_fetch` + # correctly, or we somehow accepted an event whose auth events were + # dodgy. + logger.error( + "Failed to find memberships for %s in partial state room " + "%s in the auth events of %s.", + missing_members, + room_id, + [ + events_with_membership_auth[member].event_id + for member in missing_members + ], + ) + + return additional_state_ids + async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig ) -> NotifCounts: @@ -1730,7 +1913,11 @@ class SyncHandler: continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]), + ) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: @@ -1756,7 +1943,13 @@ class SyncHandler: newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types( + [(EventTypes.Member, user_id)] + ), + ) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) -- cgit 1.5.1 From 5e7847dc923142bc68834f9b9538ada3fdd887d5 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 23 Aug 2022 10:49:59 +0100 Subject: Cache user IDs instead of profile objects (#13573) The profile objects are never used and increase cache size significantly. --- changelog.d/13573.misc | 1 + synapse/handlers/sync.py | 4 +- synapse/state/__init__.py | 13 +++--- synapse/storage/databases/main/roommember.py | 67 ++++++++++++---------------- synapse/util/caches/descriptors.py | 26 ++++++++--- 5 files changed, 57 insertions(+), 54 deletions(-) create mode 100644 changelog.d/13573.misc (limited to 'synapse/handlers/sync.py') diff --git a/changelog.d/13573.misc b/changelog.d/13573.misc new file mode 100644 index 0000000000..1ce9c0c081 --- /dev/null +++ b/changelog.d/13573.misc @@ -0,0 +1 @@ +Cache user IDs instead of profiles to reduce cache memory usage. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b4d3f3958c..2d95b1fa24 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -2421,10 +2421,10 @@ class SyncHandler: joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room( + user_ids_in_room = await self.state.get_current_user_ids_in_room( joined_room.room_id, extrems ) - if user_id in users_in_room: + if user_id in user_ids_in_room: joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c355e4f98a..3047e1b1ad 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer @@ -210,11 +209,11 @@ class StateHandler: ret = await self.resolve_state_groups_for_events(room_id, event_ids) return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_users_in_room( + async def get_current_user_ids_in_room( self, room_id: str, latest_event_ids: List[str] - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: """ - Get the users who are currently in a room. + Get the users IDs who are currently in a room. Note: This is much slower than using the equivalent method `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, @@ -225,15 +224,15 @@ class StateHandler: room_id: The ID of the room. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Dictionary of user IDs to their profileinfo. + Set of user IDs in the room. """ assert latest_event_ids is not None - logger.debug("calling resolve_state_groups from get_current_users_in_room") + logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_users_from_state(room_id, state, entry) + return await self.store.get_joined_user_ids_from_state(room_id, state, entry) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 827c1f1efd..0eb024a809 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -835,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() - async def get_joined_users_from_state( + async def get_joined_user_ids_from_state( self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: state_group: Union[object, int] = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -848,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): - return await self._get_joined_users_from_context( + return await self._get_joined_user_ids_from_context( room_id, state_group, state, context=state_entry ) @cached(num_args=2, iterable=True, max_entries=100000) - async def _get_joined_users_from_context( + async def _get_joined_user_ids_from_context( self, room_id: str, state_group: Union[object, int], current_state_ids: StateMap[str], event: Optional[EventBase] = None, context: Optional["_StateCacheEntry"] = None, - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. assert state_group is not None - users_in_room = {} + users_in_room = set() member_event_ids = [ e_id for key, e_id in current_state_ids.items() @@ -879,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get_immediate( + prev_res = self._get_joined_user_ids_from_context.cache.get_immediate( (room_id, context.prev_group), None ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) + if prev_res and isinstance(prev_res, set): + users_in_room = prev_res member_event_ids = [ e_id for key, e_id in context.delta_ids.items() @@ -891,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ] for etype, state_key in context.delta_ids: if etype == EventTypes.Member: - users_in_room.pop(state_key, None) + users_in_room.discard(state_key) # We check if we have any of the member event ids in the event cache # before we ask the DB @@ -908,42 +908,41 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: - users_in_room[ev_entry.event.state_key] = ProfileInfo( - display_name=ev_entry.event.content.get("displayname", None), - avatar_url=ev_entry.event.content.get("avatar_url", None), - ) + users_in_room.add(ev_entry.event.state_key) else: missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = await self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_user_ids_from_membership_event_ids( missing_member_event_ids ) - users_in_room.update(row for row in event_to_memberships.values() if row) + users_in_room.update(event_to_memberships.values()) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[event.state_key] = ProfileInfo( - display_name=event.content.get("displayname", None), - avatar_url=event.content.get("avatar_url", None), - ) + users_in_room.add(event.state_key) return users_in_room - @cached(max_entries=10000) - def _get_joined_profile_from_event_id( + @cached( + max_entries=10000, + # This name matches the old function that has been replaced - the cache name + # is kept here to maintain backwards compatibility. + name="_get_joined_profile_from_event_id", + ) + def _get_user_id_from_membership_event_id( self, event_id: str ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", + cached_method_name="_get_user_id_from_membership_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids( + async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: + ) -> Dict[str, str]: """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -958,21 +957,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): table="room_memberships", column="event_id", iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), + retcols=("user_id", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=1000, - desc="_get_joined_profiles_from_event_ids", + desc="_get_user_ids_from_membership_event_ids", ) - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } + return {row["event_id"]: row["user_id"] for row in rows} @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: @@ -1131,12 +1122,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. - joined_users = await self.get_joined_users_from_state( + joined_user_ids = await self.get_joined_user_ids_from_state( room_id, state, state_entry ) cache.hosts_to_joined_users = {} - for user_id in joined_users: + for user_id in joined_user_ids: host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 867f315b2a..9d4bc89edb 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -73,8 +73,10 @@ class _CacheDescriptorBase: num_args: Optional[int], uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, + name: Optional[str] = None, ): self.orig = orig + self.name = name or orig.__name__ arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -211,7 +213,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.orig.__name__, + cache_name=self.name, max_size=self.max_entries, ) @@ -241,7 +243,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) wrapped.cache = cache - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -301,12 +303,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ): super().__init__( orig, num_args=num_args, uncached_args=uncached_args, cache_context=cache_context, + name=name, ) if tree and self.num_args < 2: @@ -321,7 +325,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( - name=self.orig.__name__, + name=self.name, max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, @@ -372,7 +376,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped.cache = cache wrapped.num_args = self.num_args - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -393,6 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cached_method_name: str, list_name: str, num_args: Optional[int] = None, + name: Optional[str] = None, ): """ Args: @@ -403,7 +408,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args, uncached_args=None) + super().__init__(orig, num_args=num_args, uncached_args=None, name=name) self.list_name = list_name @@ -525,7 +530,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): else: return defer.succeed(results) - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -577,6 +582,7 @@ def cached( cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, @@ -587,13 +593,18 @@ def cached( cache_context=cache_context, iterable=iterable, prune_unread_entries=prune_unread_entries, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) def cachedList( - *, cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. @@ -628,6 +639,7 @@ def cachedList( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) -- cgit 1.5.1