diff options
author | Olivier Wilkinson (reivilibre) <olivier@librepush.net> | 2021-08-04 15:01:17 +0100 |
---|---|---|
committer | Olivier Wilkinson (reivilibre) <olivier@librepush.net> | 2021-08-04 15:06:06 +0100 |
commit | 5fa9110c24138712029335e77434836c2c4c25da (patch) | |
tree | 3a2cb5d57c4ecc6da0a33c9eca0d5ab9c3f0dbef | |
parent | Remove _get_state_groups_from_groups_txn (diff) | |
download | synapse-5fa9110c24138712029335e77434836c2c4c25da.tar.xz |
Make StateFilter frozen
-rw-r--r-- | synapse/storage/databases/state/store.py | 15 | ||||
-rw-r--r-- | synapse/storage/state.py | 27 | ||||
-rw-r--r-- | tests/storage/test_state.py | 43 |
3 files changed, 44 insertions, 41 deletions
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7119323ed4..e4b47ff8e0 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -16,8 +16,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Optional, Set, Tuple -from frozendict import frozendict - from synapse.api.constants import EventTypes from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -188,19 +186,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state map """ - # convert the state_filter.types dict into something that is hashable. - frozen_kvs = {} - for k, v in state_filter.types.items(): - if v is None: - frozen_kvs[k] = v - else: - # make the set hashable by making a frozen copy of it - frozen_kvs[k] = frozenset(v) - - state_filter_hashable = (frozendict(frozen_kvs), state_filter.include_others) - return await self._state_group_from_group_cache.wrap( - (group, state_filter_hashable), + (group, state_filter), self.db_pool.runInteraction, "_get_state_groups_from_group", self._get_state_groups_from_group_txn, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index e5400d681a..f23082f1df 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -25,6 +25,7 @@ from typing import ( ) import attr +from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -40,7 +41,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -@attr.s(slots=True) +@attr.s(slots=True, frozen=True) class StateFilter: """A filter used when querying for state. @@ -53,14 +54,16 @@ class StateFilter: appear in `types`. """ - types = attr.ib(type=Dict[str, Optional[Set[str]]]) + types = attr.ib(type=frozendict[str, Optional[Set[str]]]) include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: - self.types = {k: v for k, v in self.types.items() if v is not None} + self.types = frozendict( + {k: v for k, v in self.types.items() if v is not None} + ) @staticmethod def all() -> "StateFilter": @@ -69,7 +72,7 @@ class StateFilter: Returns: The new state filter. """ - return StateFilter(types={}, include_others=True) + return StateFilter(types=frozendict(), include_others=True) @staticmethod def none() -> "StateFilter": @@ -78,7 +81,7 @@ class StateFilter: Returns: The new state filter. """ - return StateFilter(types={}, include_others=False) + return StateFilter(types=frozendict(), include_others=False) @staticmethod def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": @@ -103,7 +106,7 @@ class StateFilter: type_dict.setdefault(typ, set()).add(s) # type: ignore - return StateFilter(types=type_dict) + return StateFilter(types=frozendict(type_dict)) @staticmethod def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": @@ -116,7 +119,9 @@ class StateFilter: Returns: The new state filter """ - return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) + return StateFilter( + types=frozendict({EventTypes.Member: set(members)}), include_others=True + ) def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed @@ -173,7 +178,7 @@ class StateFilter: # We want to return all non-members, but only particular # memberships return StateFilter( - types={EventTypes.Member: self.types[EventTypes.Member]}, + types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) @@ -324,14 +329,16 @@ class StateFilter: if state_keys is None: member_filter = StateFilter.all() else: - member_filter = StateFilter({EventTypes.Member: state_keys}) + member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) elif self.include_others: member_filter = StateFilter.all() else: member_filter = StateFilter.none() non_member_filter = StateFilter( - types={k: v for k, v in self.types.items() if k != EventTypes.Member}, + types=frozendict( + {k: v for k, v in self.types.items() if k != EventTypes.Member} + ), include_others=self.include_others, ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 8695264595..d5e9e850a9 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -14,6 +14,8 @@ import logging +from frozendict import frozendict + from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter @@ -183,7 +185,7 @@ class StateStoreTestCase(HomeserverTestCase): self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, + types=frozendict({EventTypes.Member: {self.u_alice.to_string()}}), include_others=True, ), ) @@ -203,7 +205,7 @@ class StateStoreTestCase(HomeserverTestCase): self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: set()}), include_others=True ), ) ) @@ -228,7 +230,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: set()}), include_others=True ), ) @@ -245,7 +247,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: set()}), include_others=True ), ) @@ -258,7 +260,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -275,7 +277,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -295,7 +297,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=True, ), ) @@ -312,7 +315,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=True, ), ) @@ -325,7 +329,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=False + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=False, ), ) @@ -375,7 +380,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: set()}), include_others=True ), ) @@ -387,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True + types=frozendict({EventTypes.Member: set()}), include_others=True ), ) @@ -400,7 +405,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -411,7 +416,7 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, include_others=True + types=frozendict({EventTypes.Member: None}), include_others=True ), ) @@ -430,7 +435,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=True, ), ) @@ -441,7 +447,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=True + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=True, ), ) @@ -454,7 +461,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=False + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=False, ), ) @@ -465,7 +473,8 @@ class StateStoreTestCase(HomeserverTestCase): self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, include_others=False + types=frozendict({EventTypes.Member: {e5.state_key}}), + include_others=False, ), ) |