summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_event_chain.py64
1 files changed, 38 insertions, 26 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b55dd07f14..2f6499966c 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Set, Tuple, cast
 
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
@@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase):
         self, events: List[EventBase]
     ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
         # Fetch the map from event ID -> (chain ID, sequence number)
-        rows = self.get_success(
-            self.store.db_pool.simple_select_many_batch(
-                table="event_auth_chains",
-                column="event_id",
-                iterable=[e.event_id for e in events],
-                retcols=("event_id", "chain_id", "sequence_number"),
-                keyvalues={},
-            )
+        rows = cast(
+            List[Tuple[str, int, int]],
+            self.get_success(
+                self.store.db_pool.simple_select_many_batch(
+                    table="event_auth_chains",
+                    column="event_id",
+                    iterable=[e.event_id for e in events],
+                    retcols=("event_id", "chain_id", "sequence_number"),
+                    keyvalues={},
+                )
+            ),
         )
 
         chain_map = {
-            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+            event_id: (chain_id, sequence_number)
+            for event_id, chain_id, sequence_number in rows
         }
 
         # Fetch all the links and pass them to the _LinkMap.
-        rows = self.get_success(
-            self.store.db_pool.simple_select_many_batch(
-                table="event_auth_chain_links",
-                column="origin_chain_id",
-                iterable=[chain_id for chain_id, _ in chain_map.values()],
-                retcols=(
-                    "origin_chain_id",
-                    "origin_sequence_number",
-                    "target_chain_id",
-                    "target_sequence_number",
-                ),
-                keyvalues={},
-            )
+        auth_chain_rows = cast(
+            List[Tuple[int, int, int, int]],
+            self.get_success(
+                self.store.db_pool.simple_select_many_batch(
+                    table="event_auth_chain_links",
+                    column="origin_chain_id",
+                    iterable=[chain_id for chain_id, _ in chain_map.values()],
+                    retcols=(
+                        "origin_chain_id",
+                        "origin_sequence_number",
+                        "target_chain_id",
+                        "target_sequence_number",
+                    ),
+                    keyvalues={},
+                )
+            ),
         )
 
         link_map = _LinkMap()
-        for row in rows:
+        for (
+            origin_chain_id,
+            origin_sequence_number,
+            target_chain_id,
+            target_sequence_number,
+        ) in auth_chain_rows:
             added = link_map.add_link(
-                (row["origin_chain_id"], row["origin_sequence_number"]),
-                (row["target_chain_id"], row["target_sequence_number"]),
+                (origin_chain_id, origin_sequence_number),
+                (target_chain_id, target_sequence_number),
             )
 
             # We shouldn't have persisted any redundant links