diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index b55dd07f14..2f6499966c 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Set, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
@@ -421,41 +421,53 @@ class EventChainStoreTestCase(HomeserverTestCase):
self, events: List[EventBase]
) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
# Fetch the map from event ID -> (chain ID, sequence number)
- rows = self.get_success(
- self.store.db_pool.simple_select_many_batch(
- table="event_auth_chains",
- column="event_id",
- iterable=[e.event_id for e in events],
- retcols=("event_id", "chain_id", "sequence_number"),
- keyvalues={},
- )
+ rows = cast(
+ List[Tuple[str, int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in events],
+ retcols=("event_id", "chain_id", "sequence_number"),
+ keyvalues={},
+ )
+ ),
)
chain_map = {
- row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+ event_id: (chain_id, sequence_number)
+ for event_id, chain_id, sequence_number in rows
}
# Fetch all the links and pass them to the _LinkMap.
- rows = self.get_success(
- self.store.db_pool.simple_select_many_batch(
- table="event_auth_chain_links",
- column="origin_chain_id",
- iterable=[chain_id for chain_id, _ in chain_map.values()],
- retcols=(
- "origin_chain_id",
- "origin_sequence_number",
- "target_chain_id",
- "target_sequence_number",
- ),
- keyvalues={},
- )
+ auth_chain_rows = cast(
+ List[Tuple[int, int, int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable=[chain_id for chain_id, _ in chain_map.values()],
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ keyvalues={},
+ )
+ ),
)
link_map = _LinkMap()
- for row in rows:
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in auth_chain_rows:
added = link_map.add_link(
- (row["origin_chain_id"], row["origin_sequence_number"]),
- (row["target_chain_id"], row["target_sequence_number"]),
+ (origin_chain_id, origin_sequence_number),
+ (target_chain_id, target_sequence_number),
)
# We shouldn't have persisted any redundant links
|