summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py26
1 files changed, 16 insertions, 10 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2e277a21c4..c1c147c62a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )
 
     async def store_state_group(
         self,