summary refs log tree commit diff
path: root/tests/storage/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_state.py')
-rw-r--r--tests/storage/test_state.py62
1 files changed, 36 insertions, 26 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b9446c36c..2715c73f16 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple, cast
 
 from immutabledict import immutabledict
 
@@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
         )
 
         # check that only state events are in state_groups, and all state events are in state_groups
-        res = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="state_groups",
-                keyvalues=None,
-                retcols=("event_id",),
-            )
+        res = cast(
+            List[Tuple[str]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="state_groups",
+                    keyvalues=None,
+                    retcols=("event_id",),
+                )
+            ),
         )
 
         events = []
         for result in res:
-            self.assertNotIn(event3.event_id, result)
-            events.append(result.get("event_id"))
+            self.assertNotIn(event3.event_id, result)  # XXX
+            events.append(result[0])
 
         for event, _ in processed_events_and_context:
             if event.is_state():
@@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
         # has an entry and prev event in state_group_edges
         for event, context in processed_events_and_context:
             if event.is_state():
-                state = self.get_success(
-                    self.store.db_pool.simple_select_list(
-                        table="state_groups_state",
-                        keyvalues={"state_group": context.state_group_after_event},
-                        retcols=("type", "state_key"),
-                    )
-                )
-                self.assertEqual(event.type, state[0].get("type"))
-                self.assertEqual(event.state_key, state[0].get("state_key"))
-
-                groups = self.get_success(
-                    self.store.db_pool.simple_select_list(
-                        table="state_group_edges",
-                        keyvalues={"state_group": str(context.state_group_after_event)},
-                        retcols=("*",),
-                    )
+                state = cast(
+                    List[Tuple[str, str]],
+                    self.get_success(
+                        self.store.db_pool.simple_select_list(
+                            table="state_groups_state",
+                            keyvalues={"state_group": context.state_group_after_event},
+                            retcols=("type", "state_key"),
+                        )
+                    ),
                 )
-                self.assertEqual(
-                    context.state_group_before_event, groups[0].get("prev_state_group")
+                self.assertEqual(event.type, state[0][0])
+                self.assertEqual(event.state_key, state[0][1])
+
+                groups = cast(
+                    List[Tuple[str]],
+                    self.get_success(
+                        self.store.db_pool.simple_select_list(
+                            table="state_group_edges",
+                            keyvalues={
+                                "state_group": str(context.state_group_after_event)
+                            },
+                            retcols=("prev_state_group",),
+                        )
+                    ),
                 )
+                self.assertEqual(context.state_group_before_event, groups[0][0])