From 495cb100d127212d55a46c177706d732950e70be Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Aug 2018 14:46:17 +0100 Subject: Allow profile changes to happen on workers --- synapse/storage/profile.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 60295da254..246ab836bd 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - -class ProfileStore(ProfileWorkerStore): def create_profile(self, user_localpart): return self._simple_insert( table="profiles", @@ -182,3 +180,7 @@ class ProfileStore(ProfileWorkerStore): if res: defer.returnValue(True) + + +class ProfileStore(ProfileWorkerStore): + pass -- cgit 1.5.1 From cd9765805e2792ad7d6b0b014447f51d2b9add8c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Aug 2018 10:48:31 +0100 Subject: Allow ratelimiting on workers --- synapse/storage/room.py | 58 ++++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3147fb6827..1e473cc188 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -170,6 +170,35 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) + class RoomStore(RoomWorkerStore, SearchStore): @@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore): "get_all_new_public_rooms", get_all_new_public_rooms ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) - else: - defer.returnValue(None) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( -- cgit 1.5.1 From ce6db0e5473e31292567b11599eb334c3275c564 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Aug 2018 17:01:57 +0100 Subject: Choose state algorithm based on room version --- synapse/handlers/federation.py | 8 +++- synapse/handlers/room_member.py | 5 +- synapse/state/__init__.py | 104 +++++++++++++++++++++++++++++++++++----- synapse/storage/events.py | 4 +- 4 files changed, 105 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0dffd44e22..75a819dd11 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -274,8 +274,9 @@ class FederationHandler(BaseHandler): ev_ids, get_prev_content=False, check_redacted=False ) + room_version = yield self.store.get_room_version(pdu.room_id) state_map = yield resolve_events_with_factory( - state_groups, {pdu.event_id: pdu}, fetch + room_version, state_groups, {pdu.event_id: pdu}, fetch ) state = (yield self.store.get_events(state_map.values())).values() @@ -1811,7 +1812,10 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) - new_state = self.state_handler.resolve_events( + room_version = yield self.store.get_room_version(event.room_id) + + new_state = yield self.state_handler.resolve_events( + room_version, [list(local_view.values()), list(remote_view.values())], event ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0d4a3f4677..b6010bb41e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -341,9 +341,10 @@ class RoomMemberHandler(object): prev_events_and_hashes = yield self.store.get_prev_events_for_room( room_id, ) - latest_event_ids = ( + latest_event_ids = [ event_id for (event_id, _, _) in prev_events_and_hashes - ) + ] + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 8c091d07c9..222daa0b28 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -23,16 +23,15 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext +from synapse.state import v1 from synapse.util.async import Linearizer from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.metrics import Measure -from .v1 import resolve_events_with_factory, resolve_events_with_state_map - logger = logging.getLogger(__name__) @@ -263,8 +262,14 @@ class StateHandler(object): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") + if event.type == EventTypes.Create: + room_version = event.content.get("room_version", RoomVersions.V1) + else: + room_version = None + entry = yield self.resolve_state_groups_for_events( event.room_id, [e for e, _ in event.prev_events], + explicit_room_version=room_version, ) prev_state_ids = entry.state @@ -332,13 +337,17 @@ class StateHandler(object): defer.returnValue(context) @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + def resolve_state_groups_for_events(self, room_id, event_ids, + explicit_room_version=None): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: - room_id (str): - event_ids (list[str]): + room_id (str) + event_ids (list[str]) + explicit_room_version (str|None): If set uses the the given room + version to choose the resolution algorithm. If None, then + checks the database for room version. Returns: Deferred[_StateCacheEntry]: resolved state @@ -364,8 +373,13 @@ class StateHandler(object): delta_ids=delta_ids, )) + room_version = explicit_room_version + if not room_version: + room_version = yield self.store.get_room_version(room_id) + result = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups_ids, None, self._state_map_factory, + room_id, room_version, state_groups_ids, None, + self._state_map_factory, ) defer.returnValue(result) @@ -374,7 +388,8 @@ class StateHandler(object): ev_ids, get_prev_content=False, check_redacted=False, ) - def resolve_events(self, state_sets, event): + @defer.inlineCallbacks + def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -389,14 +404,18 @@ class StateHandler(object): for ev in st } + room_version = yield self.store.get_room_version(event.room_id) + with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map(state_set_ids, state_map) + new_state = resolve_events_with_state_map( + room_version, state_set_ids, state_map, + ) new_state = { key: state_map[ev_id] for key, ev_id in iteritems(new_state) } - return new_state + defer.returnValue(new_state) class StateResolutionHandler(object): @@ -429,7 +448,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_map_factory, ): """Resolves conflicts between a set of state groups @@ -438,6 +457,7 @@ class StateResolutionHandler(object): Args: room_id (str): room we are resolving for (used for logging) + room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group (where 'state' is a map from state key to event id) @@ -491,6 +511,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_factory( + room_version, list(itervalues(state_groups_ids)), event_map=event_map, state_map_factory=state_map_factory, @@ -572,3 +593,64 @@ def _make_state_cache_entry( prev_group=prev_group, delta_ids=delta_ids, ) + + +def resolve_events_with_state_map(room_version, state_sets, state_map): + """ + Args: + room_version(str): Version of the room + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map(dict): a dict from event_id to event, for all events in + state_sets. + + Returns + dict[(str, str), str]: + a map from (type, state_key) to event_id. + """ + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_state_map( + state_sets, state_map, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) + ) + + +def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): + """ + Args: + room_version(str): Version of the room + + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_factory( + state_sets, event_map, state_map_factory, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) + ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ce32e8fefd..dc68d365c2 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore } events_map = {ev.event_id: ev for ev, _ in events_context} + room_version = yield self.get_room_version(room_id) + logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, get_events ) defer.returnValue((res.state, None)) -- cgit 1.5.1 From 38f708a2bbee13ebd34f1595910dd0e4fd532818 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Aug 2018 11:37:42 +0100 Subject: Remote profile cache should remain in master worker --- synapse/storage/profile.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 246ab836bd..88b50f33b5 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -94,6 +94,8 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) + +class ProfileStore(ProfileWorkerStore): def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. @@ -180,7 +182,3 @@ class ProfileWorkerStore(SQLBaseStore): if res: defer.returnValue(True) - - -class ProfileStore(ProfileWorkerStore): - pass -- cgit 1.5.1 From bb81e78ec6c05edc95b25a954bdf1cf688d4d652 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 22 Aug 2018 00:56:37 +0200 Subject: Split the state_group_cache in two (#3726) Splits the state_group_cache in two. One half contains normal state events; the other contains member events. The idea is that the lazyloading common case of: "I want a subset of member events plus all of the other state" can be accomplished efficiently by splitting the cache into two, and asking for "all events" from the non-members cache, and "just these keys" from the members cache. This means we can avoid having to make DictionaryCache aware of these sort of complicated queries, whilst letting LL requests benefit from the caching. Previously we were unable to sensibly use the caching and had to pull all state from the DB irrespective of the filtering, which made things slow. Hopefully fixes https://github.com/matrix-org/synapse/issues/3720. --- changelog.d/3726.misc | 1 + synapse/storage/state.py | 158 +++++++++++++++++++++++++++++++++++++++----- tests/storage/test_state.py | 105 ++++++++++++++++++++++++++--- 3 files changed, 236 insertions(+), 28 deletions(-) create mode 100644 changelog.d/3726.misc (limited to 'synapse/storage') diff --git a/changelog.d/3726.misc b/changelog.d/3726.misc new file mode 100644 index 0000000000..c4f66ec998 --- /dev/null +++ b/changelog.d/3726.misc @@ -0,0 +1 @@ +Split the state_group_cache into member and non-member state events (and so speed up LL /sync) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd03c4168b..4b971efdba 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, db_conn, hs): super(StateGroupWorkerStore, self).__init__(db_conn, hs) + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache") + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache") ) @defer.inlineCallbacks @@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types): + def _get_state_groups_from_groups(self, groups, types, members=None): """Returns the state groups for a given set of groups, filtering on types of state events. @@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types (Iterable[str, str|None]|None): list of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all state_keys for the `type`. If None, all types are returned. + members (bool|None): If not None, then, in addition to any filtering + implied by types, the results are also filtered to only include + member events (if True), or to exclude member events (if False) Returns: dictionary state_group -> (dict of (type, state_key) -> event id) @@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, + self._get_state_groups_from_groups_txn, chunk, types, members, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, + self, txn, groups, types=None, members=None, ): results = {group: {} for group in groups} @@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): %s """) + if members is True: + sql += " AND type = '%s'" % (EventTypes.Member,) + elif members is False: + sql += " AND type <> '%s'" % (EventTypes.Member,) + # Turns out that postgres doesn't like doing a list of OR's and # is about 1000x slower, so we just issue a query for each specific # type seperately. @@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: where_clause = "" + if members is True: + where_clause += " AND type = '%s'" % EventTypes.Member + elif members is False: + where_clause += " AND type <> '%s'" % EventTypes.Member + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: @@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types, filtered_types=None): + def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): """Checks if group is in cache. See `_get_state_for_groups` Args: + cache(DictionaryCache): the state group cache to use group(int): The state group to lookup types(list[str, str|None]): List of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all @@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = cache.get(group) type_to_key = {} - # tracks whether any of ourrequested types are missing from the cache + # tracks whether any of our requested types are missing from the cache missing_types = False for typ, state_key in types: @@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if include(k[0], k[1]) }, got_all - def _get_all_state_from_cache(self, group): + def _get_all_state_from_cache(self, cache, group): """Checks if group is in cache. See `_get_state_for_groups` Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool @@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache, if False we need to query the DB for the missing state. Args: + cache(DictionaryCache): the state group cache to use group: The state group to lookup """ - is_all, _, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = cache.get(group) return state_dict_ids, is_all @@ -681,6 +731,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): list of event types. Other types of events are returned unfiltered. If None, `types` filtering is applied to all events. + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. + """ + if types is not None: + non_member_types = [t for t in types if t[0] != EventTypes.Member] + + if filtered_types is not None and EventTypes.Member not in filtered_types: + # we want all of the membership events + member_types = None + else: + member_types = [t for t in types if t[0] == EventTypes.Member] + + else: + non_member_types = None + member_types = None + + non_member_state = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, non_member_types, filtered_types, + ) + # XXX: we could skip this entirely if member_types is [] + member_state = yield self._get_state_for_groups_using_cache( + # we set filtered_types=None as member_state only ever contain members. + groups, self._state_group_members_cache, member_types, None, + ) + + state = non_member_state + for group in groups: + state[group].update(member_state[group]) + + defer.returnValue(state) + + @defer.inlineCallbacks + def _get_state_for_groups_using_cache( + self, groups, cache, types=None, filtered_types=None + ): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + Returns: Deferred[dict[int, dict[(type, state_key), EventBase]]] a dictionary mapping from state group to state dictionary. @@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, got_all = self._get_some_state_from_cache( - group, types, filtered_types + cache, group, types, filtered_types ) results[group] = state_dict_ids @@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: for group in set(groups): state_dict_ids, got_all = self._get_all_state_from_cache( - group + cache, group ) results[group] = state_dict_ids @@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): missing_groups.append(group) if missing_groups: - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + # Okay, so we have some missing_types, let's fetch them. + cache_seq_num = cache.sequence # the DictionaryCache knows if it has *all* the state, but # does not know if it has all of the keys of a particular type, @@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types_to_fetch = types group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch + missing_groups, types_to_fetch, cache == self._state_group_members_cache, ) for group, group_state_dict in iteritems(group_to_state_dict): @@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # update the cache with all the things we fetched from the # database. - self._state_group_cache.update( + cache.update( cache_seq_num, key=group, value=group_state_dict, @@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ], ) - # Prefill the state group cache with this group. + # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map # is immutable. (If the map wasn't immutable then this prefill could # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_state_ids), + value=dict(current_non_member_state_ids), ) return state_group diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ebfd969b36..d717b9f94e 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -185,6 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) @@ -197,8 +198,20 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + {}, + state_dict, + ) + # test _get_some_state_from_cache correctly filters in members with wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] ) @@ -207,6 +220,18 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, @@ -216,6 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member], @@ -226,6 +252,20 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -234,6 +274,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=None ) @@ -254,9 +295,6 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, }, ) @@ -269,8 +307,6 @@ class StateStoreTestCase(tests.unittest.TestCase): # list fetched keys so it knows it's partial fetched_keys=( (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), ), ) @@ -284,8 +320,6 @@ class StateStoreTestCase(tests.unittest.TestCase): set( [ (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), ] ), ) @@ -293,8 +327,6 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict_ids, { (e1.type, e1.state_key): e1.event_id, - (e3.type, e3.state_key): e3.event_id, - (e5.type, e5.state_key): e5.event_id, }, ) @@ -304,14 +336,25 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters out members with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + room_id = self.room.to_string() + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({}, state_dict) + # test _get_some_state_from_cache correctly filters in members wildcard types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] ) @@ -319,8 +362,19 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual( { (e1.type, e1.state_key): e1.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -328,6 +382,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member], @@ -337,6 +392,20 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual( { (e1.type, e1.state_key): e1.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, + group, + [(EventTypes.Member, e5.state_key)], + filtered_types=[EventTypes.Member], + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e5.type, e5.state_key): e5.event_id, }, state_dict, @@ -345,8 +414,22 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_some_state_from_cache correctly filters in members with specific types # and no filtered_types (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_cache, + group, [(EventTypes.Member, e5.state_key)], filtered_types=None + ) + + self.assertEqual(is_all, False) + self.assertDictEqual({}, state_dict) + + (state_dict, is_all) = yield self.store._get_some_state_from_cache( + self.store._state_group_members_cache, group, [(EventTypes.Member, e5.state_key)], filtered_types=None ) self.assertEqual(is_all, True) - self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + self.assertDictEqual( + { + (e5.type, e5.state_key): e5.event_id, + }, + state_dict, + ) -- cgit 1.5.1