summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
authorJonathan de Jong <jonathan@automatia.nl>2021-04-08 23:38:54 +0200
committerGitHub <noreply@github.com>2021-04-08 22:38:54 +0100
commit2ca4e349e9d0c606d802ae15c06089080fa4f27e (patch)
tree0b3e8448124e0c0cbe128b27ab1641be02697536 /synapse/storage/state.py
parentMerge pull request #9766 from matrix-org/rav/drop_py35 (diff)
downloadsynapse-2ca4e349e9d0c606d802ae15c06089080fa4f27e.tar.xz
Bugbear: Add Mutable Parameter fixes (#9682)
Part of #9366

Adds in fixes for B006 and B008, both relating to mutable parameter lint errors.

Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py26
1 files changed, 16 insertions, 10 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2e277a21c4..c1c147c62a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -465,7 +465,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         state_event_map = await self.stores.main.get_events(
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -502,7 +502,7 @@ class StateGroupStorage:
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
-            groups, state_filter
+            groups, state_filter or StateFilter.all()
         )
 
         event_to_state = {
@@ -513,7 +513,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
@@ -525,11 +525,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+        self, event_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -541,11 +543,13 @@ class StateGroupStorage:
         Returns:
             A dict from (type, state_key) -> state_event
         """
-        state_map = await self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events(
+            [event_id], state_filter or StateFilter.all()
+        )
         return state_map[event_id]
 
     def _get_state_for_groups(
-        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
@@ -558,7 +562,9 @@ class StateGroupStorage:
         Returns:
             Dict of state group to state map.
         """
-        return self.stores.state._get_state_for_groups(groups, state_filter)
+        return self.stores.state._get_state_for_groups(
+            groups, state_filter or StateFilter.all()
+        )
 
     async def store_state_group(
         self,