diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e5400d681a..c76529cb57 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,12 +25,15 @@ from typing import (
)
import attr
+from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
+ from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
+
from synapse.server import HomeServer
from synapse.storage.databases import Databases
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.
@@ -53,14 +56,19 @@ class StateFilter:
appear in `types`.
"""
- types = attr.ib(type=Dict[str, Optional[Set[str]]])
+ types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in self.types.items() if v is not None}
+ # this is needed to work around the fact that StateFilter is frozen
+ object.__setattr__(
+ self,
+ "types",
+ frozendict({k: v for k, v in self.types.items() if v is not None}),
+ )
@staticmethod
def all() -> "StateFilter":
@@ -69,7 +77,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=True)
+ return StateFilter(types=frozendict(), include_others=True)
@staticmethod
def none() -> "StateFilter":
@@ -78,7 +86,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=False)
+ return StateFilter(types=frozendict(), include_others=False)
@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +111,12 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore
- return StateFilter(types=type_dict)
+ return StateFilter(
+ types=frozendict(
+ (k, frozenset(v) if v is not None else None)
+ for k, v in type_dict.items()
+ )
+ )
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +129,10 @@ class StateFilter:
Returns:
The new state filter
"""
- return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+ return StateFilter(
+ types=frozendict({EventTypes.Member: frozenset(members)}),
+ include_others=True,
+ )
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +189,7 @@ class StateFilter:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
- types={EventTypes.Member: self.types[EventTypes.Member]},
+ types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
@@ -245,14 +261,15 @@ class StateFilter:
return len(self.concrete_types())
- def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
- """Returns the state filtered with by this StateFilter
+ def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+ """Returns the state filtered with by this StateFilter.
Args:
state: The state map to filter
Returns:
- The filtered state map
+ The filtered state map.
+ This is a copy, so it's safe to mutate.
"""
if self.is_full():
return dict(state_dict)
@@ -324,14 +341,16 @@ class StateFilter:
if state_keys is None:
member_filter = StateFilter.all()
else:
- member_filter = StateFilter({EventTypes.Member: state_keys})
+ member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+ types=frozendict(
+ {k: v for k, v in self.types.items() if k != EventTypes.Member}
+ ),
include_others=self.include_others,
)
|