diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b9446c36c..2715c73f16 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple, cast
from immutabledict import immutabledict
@@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
)
# check that only state events are in state_groups, and all state events are in state_groups
- res = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_groups",
- keyvalues=None,
- retcols=("event_id",),
- )
+ res = cast(
+ List[Tuple[str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups",
+ keyvalues=None,
+ retcols=("event_id",),
+ )
+ ),
)
events = []
for result in res:
- self.assertNotIn(event3.event_id, result)
- events.append(result.get("event_id"))
+ self.assertNotIn(event3.event_id, result) # XXX
+ events.append(result[0])
for event, _ in processed_events_and_context:
if event.is_state():
@@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
# has an entry and prev event in state_group_edges
for event, context in processed_events_and_context:
if event.is_state():
- state = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_groups_state",
- keyvalues={"state_group": context.state_group_after_event},
- retcols=("type", "state_key"),
- )
- )
- self.assertEqual(event.type, state[0].get("type"))
- self.assertEqual(event.state_key, state[0].get("state_key"))
-
- groups = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_group_edges",
- keyvalues={"state_group": str(context.state_group_after_event)},
- retcols=("*",),
- )
+ state = cast(
+ List[Tuple[str, str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups_state",
+ keyvalues={"state_group": context.state_group_after_event},
+ retcols=("type", "state_key"),
+ )
+ ),
)
- self.assertEqual(
- context.state_group_before_event, groups[0].get("prev_state_group")
+ self.assertEqual(event.type, state[0][0])
+ self.assertEqual(event.state_key, state[0][1])
+
+ groups = cast(
+ List[Tuple[str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_group_edges",
+ keyvalues={
+ "state_group": str(context.state_group_after_event)
+ },
+ retcols=("prev_state_group",),
+ )
+ ),
)
+ self.assertEqual(context.state_group_before_event, groups[0][0])
|