diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 1fd333b707..75c09b3687 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -73,8 +74,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
+ self, txn, groups, state_filter: Optional[StateFilter] = None
):
+ state_filter = state_filter or StateFilter.all()
+
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 97ec65f757..dfcf89d91c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -15,7 +15,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
@@ -210,7 +210,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
async def _get_state_for_groups(
- self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -223,6 +223,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
+ state_filter = state_filter or StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split()
|