summary refs log tree commit diff
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <olivier@librepush.net>2021-08-04 15:01:17 +0100
committerOlivier Wilkinson (reivilibre) <olivier@librepush.net>2021-08-04 15:06:06 +0100
commit5fa9110c24138712029335e77434836c2c4c25da (patch)
tree3a2cb5d57c4ecc6da0a33c9eca0d5ab9c3f0dbef
parentRemove _get_state_groups_from_groups_txn (diff)
downloadsynapse-5fa9110c24138712029335e77434836c2c4c25da.tar.xz
Make StateFilter frozen
-rw-r--r--synapse/storage/databases/state/store.py15
-rw-r--r--synapse/storage/state.py27
-rw-r--r--tests/storage/test_state.py43
3 files changed, 44 insertions, 41 deletions
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7119323ed4..e4b47ff8e0 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -16,8 +16,6 @@ import logging
 from collections import namedtuple
 from typing import Dict, Iterable, List, Optional, Set, Tuple
 
-from frozendict import frozendict
-
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -188,19 +186,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state map
         """
 
-        # convert the state_filter.types dict into something that is hashable.
-        frozen_kvs = {}
-        for k, v in state_filter.types.items():
-            if v is None:
-                frozen_kvs[k] = v
-            else:
-                # make the set hashable by making a frozen copy of it
-                frozen_kvs[k] = frozenset(v)
-
-        state_filter_hashable = (frozendict(frozen_kvs), state_filter.include_others)
-
         return await self._state_group_from_group_cache.wrap(
-            (group, state_filter_hashable),
+            (group, state_filter),
             self.db_pool.runInteraction,
             "_get_state_groups_from_group",
             self._get_state_groups_from_group_txn,
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e5400d681a..f23082f1df 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,6 +25,7 @@ from typing import (
 )
 
 import attr
+from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -40,7 +41,7 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
 class StateFilter:
     """A filter used when querying for state.
 
@@ -53,14 +54,16 @@ class StateFilter:
             appear in `types`.
     """
 
-    types = attr.ib(type=Dict[str, Optional[Set[str]]])
+    types = attr.ib(type=frozendict[str, Optional[Set[str]]])
     include_others = attr.ib(default=False, type=bool)
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
-            self.types = {k: v for k, v in self.types.items() if v is not None}
+            self.types = frozendict(
+                {k: v for k, v in self.types.items() if v is not None}
+            )
 
     @staticmethod
     def all() -> "StateFilter":
@@ -69,7 +72,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=True)
+        return StateFilter(types=frozendict(), include_others=True)
 
     @staticmethod
     def none() -> "StateFilter":
@@ -78,7 +81,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=False)
+        return StateFilter(types=frozendict(), include_others=False)
 
     @staticmethod
     def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +106,7 @@ class StateFilter:
 
             type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
-        return StateFilter(types=type_dict)
+        return StateFilter(types=frozendict(type_dict))
 
     @staticmethod
     def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +119,9 @@ class StateFilter:
         Returns:
             The new state filter
         """
-        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+        return StateFilter(
+            types=frozendict({EventTypes.Member: set(members)}), include_others=True
+        )
 
     def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +178,7 @@ class StateFilter:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
-                types={EventTypes.Member: self.types[EventTypes.Member]},
+                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
 
@@ -324,14 +329,16 @@ class StateFilter:
             if state_keys is None:
                 member_filter = StateFilter.all()
             else:
-                member_filter = StateFilter({EventTypes.Member: state_keys})
+                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
         elif self.include_others:
             member_filter = StateFilter.all()
         else:
             member_filter = StateFilter.none()
 
         non_member_filter = StateFilter(
-            types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+            types=frozendict(
+                {k: v for k, v in self.types.items() if k != EventTypes.Member}
+            ),
             include_others=self.include_others,
         )
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8695264595..d5e9e850a9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -14,6 +14,8 @@
 
 import logging
 
+from frozendict import frozendict
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
@@ -183,7 +185,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    types=frozendict({EventTypes.Member: {self.u_alice.to_string()}}),
                     include_others=True,
                 ),
             )
@@ -203,7 +205,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: set()}, include_others=True
+                    types=frozendict({EventTypes.Member: set()}), include_others=True
                 ),
             )
         )
@@ -228,7 +230,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: set()}), include_others=True
             ),
         )
 
@@ -245,7 +247,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: set()}), include_others=True
             ),
         )
 
@@ -258,7 +260,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -275,7 +277,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -295,7 +297,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=True,
             ),
         )
 
@@ -312,7 +315,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=True,
             ),
         )
 
@@ -325,7 +329,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=False,
             ),
         )
 
@@ -375,7 +380,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: set()}), include_others=True
             ),
         )
 
@@ -387,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: set()}), include_others=True
             ),
         )
 
@@ -400,7 +405,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -411,7 +416,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -430,7 +435,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=True,
             ),
         )
 
@@ -441,7 +447,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=True,
             ),
         )
 
@@ -454,7 +461,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=False,
             ),
         )
 
@@ -465,7 +473,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: {e5.state_key}}),
+                include_others=False,
             ),
         )