diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4bbcf7199c..5a1a3e8e65 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
* The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
"""
- (
- user_id_device_id_to_messages,
- last_processed_stream_id,
- ) = await self._get_device_messages(
- user_ids=[user_id],
- device_id=device_id,
- from_stream_id=from_stream_id,
- to_stream_id=to_stream_id,
- limit=limit,
- )
-
- if not user_id_device_id_to_messages:
+ if not self._device_inbox_stream_cache.has_entity_changed(
+ user_id, from_stream_id
+ ):
# There were no messages!
return [], to_stream_id
- # Extract the messages, no need to return the user and device ID again
- to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+ def get_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+ sql = """
+ SELECT stream_id, message_json FROM device_inbox
+ WHERE user_id = ? AND device_id = ?
+ AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit))
+
+ # Create and fill a dictionary of (user ID, device ID) -> list of messages
+ # intended for each device.
+ last_processed_stream_pos = to_stream_id
+ to_device_messages: List[JsonDict] = []
+ rowcount = 0
+ for row in txn:
+ rowcount += 1
+
+ last_processed_stream_pos = row[0]
+ message_dict = db_to_json(row[1])
+
+ # Store the device details
+ to_device_messages.append(message_dict)
- return to_device_messages, last_processed_stream_id
+ # start a new span for each message, so that we can tag each separately
+ with start_active_span("get_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+
+ if rowcount == limit:
+ # We ended up bumping up against the message limit. There may be more messages
+ # to retrieve. Return what we have, as well as the last stream position that
+ # was processed.
+ #
+ # The caller is expected to set this as the lower (exclusive) bound
+ # for the next query of this device.
+ return to_device_messages, last_processed_stream_pos
+
+ # The limit was not reached, thus we know that recipient_device_to_messages
+ # contains all to-device messages for the given device and stream id range.
+ #
+ # We return to_stream_id, which the caller should then provide as the lower
+ # (exclusive) bound on the next query of this device.
+ return to_device_messages, to_stream_id
+
+ return await self.db_pool.runInteraction(
+ "get_messages_for_device", get_device_messages_txn
+ )
async def _get_device_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
- device_id: Optional[str] = None,
- limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
@@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
- device_id: A device ID to query to-device messages for. If not provided, to-device
- messages from all device IDs for the given user IDs will be queried. May not be
- provided if `user_ids` contains more than one entry.
- limit: The maximum number of to-device messages to return. Can only be used when
- passing a single user ID / device ID tuple.
+
Returns:
A tuple containing:
@@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id
- # Prevent a query for one user's device also retrieving another user's device with
- # the same device ID (device IDs are not unique across users).
- if len(user_ids) > 1 and device_id is not None:
- raise AssertionError(
- "Programming error: 'device_id' cannot be supplied to "
- "_get_device_messages when >1 user_id has been provided"
- )
-
- # A limit can only be applied when querying for a single user ID / device ID tuple.
- # See the docstring of this function for more details.
- if limit is not None and device_id is None:
- raise AssertionError(
- "Programming error: _get_device_messages was passed 'limit' "
- "without a specific user_id/device_id"
- )
-
user_ids_to_query: Set[str] = set()
- device_ids_to_query: Set[str] = set()
-
- # Note that a device ID could be an empty str
- if device_id is not None:
- # If a device ID was passed, use it to filter results.
- # Otherwise, device IDs will be derived from the given collection of user IDs.
- device_ids_to_query.add(device_id)
# Determine which users have devices with pending messages
for user_id in user_ids:
@@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
- if not device_ids_to_query:
- user_device_dicts = cast(
- List[Tuple[str]],
- self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- column="user_id",
- iterable=user_ids_to_query,
- keyvalues={"hidden": False},
- retcols=("device_id",),
- ),
- )
- device_ids_to_query.update({row[0] for row in user_device_dicts})
+ user_device_dicts = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ ),
+ )
+
+ device_ids_to_query = {row[0] for row in user_device_dicts}
if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
to_stream_id,
)
- # If a limit was provided, limit the data retrieved from the database
- if limit is not None:
- sql += "LIMIT ?"
- sql_args += (limit,)
-
txn.execute(sql, sql_args)
# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
- last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
rowcount = 0
for row in txn:
rowcount += 1
- last_processed_stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
@@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
)
- if limit is not None and rowcount == limit:
- # We ended up bumping up against the message limit. There may be more messages
- # to retrieve. Return what we have, as well as the last stream position that
- # was processed.
- #
- # The caller is expected to set this as the lower (exclusive) bound
- # for the next query of this device.
- return recipient_device_to_messages, last_processed_stream_pos
-
- # The limit was not reached, thus we know that recipient_device_to_messages
- # contains all to-device messages for the given device and stream id range.
- #
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 38029710db..d3859014b6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1796,7 +1796,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_ids: The IDs of the devices to delete
"""
- def _delete_devices_txn(txn: LoggingTransaction) -> None:
+ def _delete_devices_txn(txn: LoggingTransaction, device_ids: List[str]) -> None:
self.db_pool.simple_delete_many_txn(
txn,
table="devices",
@@ -1813,7 +1813,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
- await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
+ for batch in batch_iter(device_ids, 100):
+ await self.db_pool.runInteraction(
+ "delete_devices", _delete_devices_txn, batch
+ )
+
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 650b8c8135..6d4e2942ea 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -357,10 +357,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
This function is intentionally not cached because it is called to calculate the
unread badge for push notifications and thus the result is expected to change.
- Note that this function assumes the user is a member of the room. Because
- summary rows are not removed when a user leaves a room, the caller must
- filter out those results from the result.
-
Returns:
A map of room ID to notification counts for the given user.
"""
@@ -373,127 +369,170 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_unread_counts_by_room_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
- receipt_types_clause, args = make_in_list_sql_clause(
+ # To get the badge count of all rooms we need to make three queries:
+ # 1. Fetch all counts from `event_push_summary`, discarding any stale
+ # rooms.
+ # 2. Fetch all notifications from `event_push_actions` that haven't
+ # been rotated yet.
+ # 3. Fetch all notifications from `event_push_actions` for the stale
+ # rooms.
+ #
+ # The "stale room" scenario generally happens when there is a new read
+ # receipt that hasn't yet been processed to update the
+ # `event_push_summary` table. When that happens we ignore the
+ # `event_push_summary` table for that room and calculate the count
+ # manually from `event_push_actions`.
+
+ # We need to only take into account read receipts of these types.
+ receipt_types_clause, receipt_types_args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
- args.extend([user_id, user_id])
-
- receipts_cte = f"""
- WITH all_receipts AS (
- SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
- FROM receipts_linearized
- LEFT JOIN events USING (room_id, event_id)
- WHERE
- {receipt_types_clause}
- AND user_id = ?
- GROUP BY room_id, thread_id
- )
- """
-
- receipts_joins = """
- LEFT JOIN (
- SELECT room_id, thread_id,
- max_receipt_stream_ordering AS threaded_receipt_stream_ordering
- FROM all_receipts
- WHERE thread_id IS NOT NULL
- ) AS threaded_receipts USING (room_id, thread_id)
- LEFT JOIN (
- SELECT room_id, thread_id,
- max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
- FROM all_receipts
- WHERE thread_id IS NULL
- ) AS unthreaded_receipts USING (room_id)
- """
-
- # First get summary counts by room / thread for the user. We use the max receipt
- # stream ordering of both threaded & unthreaded receipts to compare against the
- # summary table.
- #
- # PostgreSQL and SQLite differ in comparing scalar numerics.
- if isinstance(self.database_engine, PostgresEngine):
- # GREATEST ignores NULLs.
- max_clause = """GREATEST(
- threaded_receipt_stream_ordering,
- unthreaded_receipt_stream_ordering
- )"""
- else:
- # MAX returns NULL if any are NULL, so COALESCE to 0 first.
- max_clause = """MAX(
- COALESCE(threaded_receipt_stream_ordering, 0),
- COALESCE(unthreaded_receipt_stream_ordering, 0)
- )"""
+ # Step 1, fetch all counts from `event_push_summary` for the user. This
+ # is slightly convoluted as we also need to pull out the stream ordering
+ # of the most recent receipt of the user in the room (either a thread
+ # aware receipt or thread unaware receipt) in order to determine
+ # whether the row in `event_push_summary` is stale. Hence the outer
+ # GROUP BY and odd join condition against `receipts_linearized`.
sql = f"""
- {receipts_cte}
- SELECT eps.room_id, eps.thread_id, notif_count
- FROM event_push_summary AS eps
- {receipts_joins}
- WHERE user_id = ?
- AND notif_count != 0
- AND (
- (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
- OR last_receipt_stream_ordering = {max_clause}
+ SELECT room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering,
+ MAX(receipt_stream_ordering)
+ FROM (
+ SELECT e.room_id, notif_count, e.stream_ordering, e.thread_id, last_receipt_stream_ordering,
+ ev.stream_ordering AS receipt_stream_ordering
+ FROM event_push_summary AS e
+ INNER JOIN local_current_membership USING (user_id, room_id)
+ LEFT JOIN receipts_linearized AS r ON (
+ e.user_id = r.user_id
+ AND e.room_id = r.room_id
+ AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+ AND {receipt_types_clause}
)
+ LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+ WHERE e.user_id = ? and notif_count > 0
+ ) AS es
+ GROUP BY room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering
"""
- txn.execute(sql, args)
-
- seen_thread_ids = set()
- room_to_count: Dict[str, int] = defaultdict(int)
- for room_id, thread_id, notif_count in txn:
- room_to_count[room_id] += notif_count
- seen_thread_ids.add(thread_id)
+ txn.execute(
+ sql,
+ receipt_types_args
+ + [
+ user_id,
+ ],
+ )
- # Now get any event push actions that haven't been rotated using the same OR
- # join and filter by receipt and event push summary rotated up to stream ordering.
- sql = f"""
- {receipts_cte}
- SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
- FROM event_push_actions AS epa
- {receipts_joins}
- WHERE user_id = ?
- AND epa.notif = 1
- AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
- AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
- AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
- GROUP BY epa.room_id, epa.thread_id
- """
- txn.execute(sql, args)
+ room_to_count: Dict[str, int] = defaultdict(int)
+ stale_room_ids = set()
+ for row in txn:
+ room_id = row[0]
+ notif_count = row[1]
+ stream_ordering = row[2]
+ _thread_id = row[3]
+ last_receipt_stream_ordering = row[4]
+ receipt_stream_ordering = row[5]
+
+ if last_receipt_stream_ordering is None:
+ if receipt_stream_ordering is None:
+ room_to_count[room_id] += notif_count
+ elif stream_ordering > receipt_stream_ordering:
+ room_to_count[room_id] += notif_count
+ else:
+ # The latest read receipt from the user is after all the rows for
+ # this room in `event_push_summary`. We ignore them, and
+ # calculate the count from `event_push_actions` in step 3.
+ pass
+ elif last_receipt_stream_ordering == receipt_stream_ordering:
+ room_to_count[room_id] += notif_count
+ else:
+ # The row is stale if `last_receipt_stream_ordering` is set and
+ # *doesn't* match the latest receipt from the user.
+ stale_room_ids.add(room_id)
- for room_id, thread_id, notif_count in txn:
- # Note: only count push actions we have valid summaries for with up to date receipt.
- if thread_id not in seen_thread_ids:
- continue
- room_to_count[room_id] += notif_count
+ # Discard any stale rooms from `room_to_count`, as we will recalculate
+ # them in step 3.
+ for room_id in stale_room_ids:
+ room_to_count.pop(room_id, None)
- thread_id_clause, thread_ids_args = make_in_list_sql_clause(
- self.database_engine, "epa.thread_id", seen_thread_ids
+ # Step 2, basically the same query, except against `event_push_actions`
+ # and only fetching rows inserted since the last rotation.
+ rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
)
- # Finally re-check event_push_actions for any rooms not in the summary, ignoring
- # the rotated up-to position. This handles the case where a read receipt has arrived
- # but not been rotated meaning the summary table is out of date, so we go back to
- # the push actions table.
sql = f"""
- {receipts_cte}
- SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
- FROM event_push_actions AS epa
- {receipts_joins}
- WHERE user_id = ?
- AND NOT {thread_id_clause}
- AND epa.notif = 1
- AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
- AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
- GROUP BY epa.room_id
+ SELECT room_id, thread_id
+ FROM (
+ SELECT e.room_id, e.stream_ordering, e.thread_id,
+ ev.stream_ordering AS receipt_stream_ordering
+ FROM event_push_actions AS e
+ INNER JOIN local_current_membership USING (user_id, room_id)
+ LEFT JOIN receipts_linearized AS r ON (
+ e.user_id = r.user_id
+ AND e.room_id = r.room_id
+ AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+ AND {receipt_types_clause}
+ )
+ LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+ WHERE e.user_id = ? and notif > 0
+ AND e.stream_ordering > ?
+ ) AS es
+ GROUP BY room_id, stream_ordering, thread_id
+ HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0)
"""
- args.extend(thread_ids_args)
- txn.execute(sql, args)
+ txn.execute(
+ sql,
+ receipt_types_args + [user_id, rotated_upto_stream_ordering],
+ )
+ for room_id, _thread_id in txn:
+ # Again, we ignore any stale rooms.
+ if room_id not in stale_room_ids:
+ # For event push actions it is one notification per row.
+ room_to_count[room_id] += 1
+
+ # Step 3, if we have stale rooms then we need to recalculate the counts
+ # from `event_push_actions`. Again, this is basically the same query as
+ # above except without a lower bound on stream ordering and only against
+ # a specific set of rooms.
+ if stale_room_ids:
+ room_id_clause, room_id_args = make_in_list_sql_clause(
+ self.database_engine,
+ "e.room_id",
+ stale_room_ids,
+ )
- for room_id, notif_count in txn:
- room_to_count[room_id] += notif_count
+ sql = f"""
+ SELECT room_id, thread_id
+ FROM (
+ SELECT e.room_id, e.stream_ordering, e.thread_id,
+ ev.stream_ordering AS receipt_stream_ordering
+ FROM event_push_actions AS e
+ INNER JOIN local_current_membership USING (user_id, room_id)
+ LEFT JOIN receipts_linearized AS r ON (
+ e.user_id = r.user_id
+ AND e.room_id = r.room_id
+ AND (e.thread_id = r.thread_id OR r.thread_id IS NULL)
+ AND {receipt_types_clause}
+ )
+ LEFT JOIN events AS ev ON (r.event_id = ev.event_id)
+ WHERE e.user_id = ? and notif > 0
+ AND {room_id_clause}
+ ) AS es
+ GROUP BY room_id, stream_ordering, thread_id
+ HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0)
+ """
+ txn.execute(
+ sql,
+ receipt_types_args + [user_id] + room_id_args,
+ )
+ for room_id, _ in txn:
+ room_to_count[room_id] += 1
return room_to_count
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4700e74ad2..8006046453 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -24,13 +24,17 @@ from typing import (
Any,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
+ TypeVar,
+ Union,
cast,
+ overload,
)
import attr
@@ -52,7 +56,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict, JsonMapping, StateMap
+from synapse.types import JsonDict, JsonMapping, StateKey, StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -64,6 +68,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+_T = TypeVar("_T")
+
MAX_STATE_DELTA_HOPS = 100
@@ -318,6 +324,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_partial_current_state_ids", _get_current_state_ids_txn
)
+ async def check_if_events_in_current_state(
+ self, event_ids: StrCollection
+ ) -> FrozenSet[str]:
+ """Checks and returns which of the given events is part of the current state."""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("event_id",),
+ desc="check_if_events_in_current_state",
+ )
+
+ return frozenset(event_id for event_id, in rows)
+
# FIXME: how should this be cached?
@cancellable
async def get_partial_filtered_current_state_ids(
@@ -349,7 +369,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
- results = {}
+ results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
+
sql = """
SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
@@ -726,3 +747,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
+
+
+@attr.s(auto_attribs=True, slots=True)
+class StateMapWrapper(Dict[StateKey, str]):
+ """A wrapper around a StateMap[str] to ensure that we only query for items
+ that were not filtered out.
+
+ This is to help prevent bugs where we filter out state but other bits of the
+ code expect the state to be there.
+ """
+
+ state_filter: StateFilter
+
+ def __getitem__(self, key: StateKey) -> str:
+ if key not in self.state_filter:
+ raise Exception("State map was filtered and doesn't include: %s", key)
+ return super().__getitem__(key)
+
+ @overload
+ def get(self, key: Tuple[str, str]) -> Optional[str]:
+ ...
+
+ @overload
+ def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
+ ...
+
+ def get(
+ self, key: StateKey, default: Union[str, _T, None] = None
+ ) -> Union[str, _T, None]:
+ if key not in self.state_filter:
+ raise Exception("State map was filtered and doesn't include: %s", key)
+ return super().get(key, default)
+
+ def __contains__(self, key: Any) -> bool:
+ if key not in self.state_filter:
+ raise Exception("State map was filtered and doesn't include: %s", key)
+
+ return super().__contains__(key)
|