summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-26 13:01:36 -0400
committerGitHub <noreply@github.com>2023-10-26 13:01:36 -0400
commit9407d5ba78d1e5275b5817ae9e6aedf7d1ca14f7 (patch)
tree70935c19b787e89115d6f8884f3d134a6bacf264 /tests
parentPin the recommended poetry version in contributors' guide (#16550) (diff)
downloadsynapse-9407d5ba78d1e5275b5817ae9e6aedf7d1ca14f7.tar.xz
Convert simple_select_list and simple_select_list_txn to return lists of tuples (#16505)
This should use fewer allocations and improves type hints.
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_stats.py14
-rw-r--r--tests/storage/databases/main/test_receipts.py20
-rw-r--r--tests/storage/test__base.py16
-rw-r--r--tests/storage/test_background_update.py35
-rw-r--r--tests/storage/test_base.py4
-rw-r--r--tests/storage/test_client_ips.py250
-rw-r--r--tests/storage/test_roommember.py40
-rw-r--r--tests/storage/test_state.py62
-rw-r--r--tests/storage/test_user_directory.py61
9 files changed, 262 insertions, 240 deletions
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d11ded6c5b..76c56d5434 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, cast
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-    async def get_all_room_state(self) -> List[Dict[str, Any]]:
-        return await self.store.db_pool.simple_select_list(
-            "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
+    async def get_all_room_state(self) -> List[Optional[str]]:
+        rows = cast(
+            List[Tuple[Optional[str]]],
+            await self.store.db_pool.simple_select_list(
+                "room_stats_state", None, retcols=("topic",)
+            ),
         )
+        return [r[0] for r in rows]
 
     def _get_current_stats(
         self, stats_type: str, stat_id: str
@@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         r = self.get_success(self.get_all_room_state())
 
         self.assertEqual(len(r), 1)
-        self.assertEqual(r[0]["topic"], "foo")
+        self.assertEqual(r[0], "foo")
 
     def test_create_user(self) -> None:
         """
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index 71db47405e..98b01086bc 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
             if expected_row is not None:
                 columns += expected_row.keys()
 
-            rows = self.get_success(
+            row_tuples = self.get_success(
                 self.store.db_pool.simple_select_list(
                     table=table,
                     keyvalues={
@@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
 
             if expected_row is not None:
                 self.assertEqual(
-                    len(rows),
+                    len(row_tuples),
                     1,
                     f"Background update did not leave behind latest receipt in {table}",
                 )
                 self.assertEqual(
-                    rows[0],
-                    {
-                        "room_id": room_id,
-                        "receipt_type": receipt_type,
-                        "user_id": user_id,
-                        **expected_row,
-                    },
+                    row_tuples[0],
+                    (
+                        room_id,
+                        receipt_type,
+                        user_id,
+                        *expected_row.values(),
+                    ),
                 )
             else:
                 self.assertEqual(
-                    len(rows),
+                    len(row_tuples),
                     0,
                     f"Background update did not remove all duplicate receipts from {table}",
                 )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8bbf936ae9..8cbc974ac4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import secrets
-from typing import Generator, Tuple
+from typing import Generator, List, Tuple, cast
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
         )
 
     def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
-        res = self.get_success(
-            self.storage.db_pool.simple_select_list(
-                self.table_name, None, ["id, username, value"]
-            )
+        yield from cast(
+            List[Tuple[int, str, str]],
+            self.get_success(
+                self.storage.db_pool.simple_select_list(
+                    self.table_name, None, ["id, username, value"]
+                )
+            ),
         )
 
-        for i in res:
-            yield (i["id"], i["username"], i["value"])
-
     def test_upsert_many(self) -> None:
         """
         Upsert_many will perform the upsert operation across a batch of data.
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index abf7d0564d..3f5bfa09d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple, cast
 from unittest.mock import AsyncMock, Mock
 
 import yaml
@@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
             self.wait_for_background_updates()
 
         # Check the correct values are in the new table.
-        rows = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="test_constraint",
-                keyvalues={},
-                retcols=("a", "b"),
-            )
+        rows = cast(
+            List[Tuple[int, int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="test_constraint",
+                    keyvalues={},
+                    retcols=("a", "b"),
+                )
+            ),
         )
 
-        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+        self.assertCountEqual(rows, [(1, 1), (3, 3)])
 
         # And check that invalid rows get correctly rejected.
         self.get_failure(
@@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
             self.wait_for_background_updates()
 
         # Check the correct values are in the new table.
-        rows = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="test_constraint",
-                keyvalues={},
-                retcols=("a", "b"),
-            )
+        rows = cast(
+            List[Tuple[int, int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="test_constraint",
+                    keyvalues={},
+                    retcols=("a", "b"),
+                )
+            ),
         )
-        self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+        self.assertCountEqual(rows, [(1, 1), (3, 3)])
 
         # And check that invalid rows get correctly rejected.
         self.get_failure(
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 256d28e4c9..e4a52c301e 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
         self.mock_txn.rowcount = 3
-        self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
+        self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
         ret = yield defer.ensureDeferred(
@@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
+        self.assertEqual([(1,), (2,), (3,)], ret)
         self.mock_txn.execute.assert_called_with(
             "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
         )
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0c054a598f..8e4393d843 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional, Tuple, cast
 from unittest.mock import AsyncMock
 
 from parameterized import parameterized
@@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(200)
         self.pump(0)
 
-        result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="user_ips",
-                keyvalues={"user_id": user_id},
-                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
-                desc="get_user_ip_and_agents",
-            )
+        result = cast(
+            List[Tuple[str, str, str, Optional[str], int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={"user_id": user_id},
+                    retcols=[
+                        "access_token",
+                        "ip",
+                        "user_agent",
+                        "device_id",
+                        "last_seen",
+                    ],
+                    desc="get_user_ip_and_agents",
+                )
+            ),
         )
 
         self.assertEqual(
-            result,
-            [
-                {
-                    "access_token": "access_token",
-                    "ip": "ip",
-                    "user_agent": "user_agent",
-                    "device_id": None,
-                    "last_seen": 12345678000,
-                }
-            ],
+            result, [("access_token", "ip", "user_agent", None, 12345678000)]
         )
 
         # Add another & trigger the storage loop
@@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
         self.pump(0)
 
-        result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="user_ips",
-                keyvalues={"user_id": user_id},
-                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
-                desc="get_user_ip_and_agents",
-            )
+        result = cast(
+            List[Tuple[str, str, str, Optional[str], int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={"user_id": user_id},
+                    retcols=[
+                        "access_token",
+                        "ip",
+                        "user_agent",
+                        "device_id",
+                        "last_seen",
+                    ],
+                    desc="get_user_ip_and_agents",
+                )
+            ),
         )
         # Only one result, has been upserted.
         self.assertEqual(
-            result,
-            [
-                {
-                    "access_token": "access_token",
-                    "ip": "ip",
-                    "user_agent": "user_agent",
-                    "device_id": None,
-                    "last_seen": 12345878000,
-                }
-            ],
+            result, [("access_token", "ip", "user_agent", None, 12345878000)]
         )
 
     @parameterized.expand([(False,), (True,)])
@@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             self.reactor.advance(10)
         else:
             # Check that the new IP and user agent has not been stored yet
-            db_result = self.get_success(
-                self.store.db_pool.simple_select_list(
-                    table="devices",
-                    keyvalues={},
-                    retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            db_result = cast(
+                List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+                self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="devices",
+                        keyvalues={},
+                        retcols=(
+                            "user_id",
+                            "ip",
+                            "user_agent",
+                            "device_id",
+                            "last_seen",
+                        ),
+                    ),
                 ),
             )
-            self.assertEqual(
-                db_result,
-                [
-                    {
-                        "user_id": user_id,
-                        "device_id": device_id,
-                        "ip": None,
-                        "user_agent": None,
-                        "last_seen": None,
-                    },
-                ],
-            )
+            self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
 
         result = self.get_success(
             self.store.get_last_client_ip_by_device(user_id, device_id)
@@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # Check that the new IP and user agent has not been stored yet
-        db_result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="devices",
-                keyvalues={},
-                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        db_result = cast(
+            List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="devices",
+                    keyvalues={},
+                    retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+                ),
             ),
         )
         self.assertCountEqual(
             db_result,
             [
-                {
-                    "user_id": user_id,
-                    "device_id": device_id_1,
-                    "ip": "ip_1",
-                    "user_agent": "user_agent_1",
-                    "last_seen": 12345678000,
-                },
-                {
-                    "user_id": user_id,
-                    "device_id": device_id_2,
-                    "ip": "ip_2",
-                    "user_agent": "user_agent_2",
-                    "last_seen": 12345678000,
-                },
+                (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
+                (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
             ],
         )
 
@@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # Check that the new IP and user agent has not been stored yet
-        db_result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="user_ips",
-                keyvalues={},
-                retcols=("access_token", "ip", "user_agent", "last_seen"),
+        db_result = cast(
+            List[Tuple[str, str, str, int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={},
+                    retcols=("access_token", "ip", "user_agent", "last_seen"),
+                ),
             ),
         )
         self.assertEqual(
             db_result,
             [
-                {
-                    "access_token": "access_token",
-                    "ip": "ip_1",
-                    "user_agent": "user_agent_1",
-                    "last_seen": 12345678000,
-                },
-                {
-                    "access_token": "access_token",
-                    "ip": "ip_2",
-                    "user_agent": "user_agent_2",
-                    "last_seen": 12345678000,
-                },
+                ("access_token", "ip_1", "user_agent_1", 12345678000),
+                ("access_token", "ip_2", "user_agent_2", 12345678000),
             ],
         )
 
@@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(200)
 
         # We should see that in the DB
-        result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="user_ips",
-                keyvalues={"user_id": user_id},
-                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
-                desc="get_user_ip_and_agents",
-            )
+        result = cast(
+            List[Tuple[str, str, str, Optional[str], int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={"user_id": user_id},
+                    retcols=[
+                        "access_token",
+                        "ip",
+                        "user_agent",
+                        "device_id",
+                        "last_seen",
+                    ],
+                    desc="get_user_ip_and_agents",
+                )
+            ),
         )
 
         self.assertEqual(
             result,
-            [
-                {
-                    "access_token": "access_token",
-                    "ip": "ip",
-                    "user_agent": "user_agent",
-                    "device_id": device_id,
-                    "last_seen": 0,
-                }
-            ],
+            [("access_token", "ip", "user_agent", device_id, 0)],
         )
 
         # Now advance by a couple of months
         self.reactor.advance(60 * 24 * 60 * 60)
 
         # We should get no results.
-        result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="user_ips",
-                keyvalues={"user_id": user_id},
-                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
-                desc="get_user_ip_and_agents",
-            )
+        result = cast(
+            List[Tuple[str, str, str, Optional[str], int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={"user_id": user_id},
+                    retcols=[
+                        "access_token",
+                        "ip",
+                        "user_agent",
+                        "device_id",
+                        "last_seen",
+                    ],
+                    desc="get_user_ip_and_agents",
+                )
+            ),
         )
 
         self.assertEqual(result, [])
@@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(200)
 
         # We should see that in the DB
-        result = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="user_ips",
-                keyvalues={},
-                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
-                desc="get_user_ip_and_agents",
-            )
+        result = cast(
+            List[Tuple[str, str, str, Optional[str], int]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="user_ips",
+                    keyvalues={},
+                    retcols=[
+                        "access_token",
+                        "ip",
+                        "user_agent",
+                        "device_id",
+                        "last_seen",
+                    ],
+                    desc="get_user_ip_and_agents",
+                )
+            ),
         )
 
         # ensure user1 is filtered out
-        self.assertEqual(
-            result,
-            [
-                {
-                    "access_token": access_token2,
-                    "ip": "ip",
-                    "user_agent": "user_agent",
-                    "device_id": device_id2,
-                    "last_seen": 0,
-                }
-            ],
-        )
+        self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
 
 
 class ClientIpAuthTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f4c4661aaf..36fcab06b5 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Optional, Tuple, cast
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import Membership
@@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
     def test__null_byte_in_display_name_properly_handled(self) -> None:
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 
-        res = self.get_success(
-            self.store.db_pool.simple_select_list(
-                "room_memberships",
-                {"user_id": "@alice:test"},
-                ["display_name", "event_id"],
-            )
+        res = cast(
+            List[Tuple[Optional[str], str]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    "room_memberships",
+                    {"user_id": "@alice:test"},
+                    ["display_name", "event_id"],
+                )
+            ),
         )
         # Check that we only got one result back
         self.assertEqual(len(res), 1)
 
         # Check that alice's display name is "alice"
-        self.assertEqual(res[0]["display_name"], "alice")
+        self.assertEqual(res[0][0], "alice")
 
         # Grab the event_id to use later
-        event_id = res[0]["event_id"]
+        event_id = res[0][1]
 
         # Create a profile with the offending null byte in the display name
         new_profile = {"displayname": "ali\u0000ce"}
@@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
             tok=self.t_alice,
         )
 
-        res2 = self.get_success(
-            self.store.db_pool.simple_select_list(
-                "room_memberships",
-                {"user_id": "@alice:test"},
-                ["display_name", "event_id"],
-            )
+        res2 = cast(
+            List[Tuple[Optional[str], str]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    "room_memberships",
+                    {"user_id": "@alice:test"},
+                    ["display_name", "event_id"],
+                )
+            ),
         )
         # Check that we only have two results
         self.assertEqual(len(res2), 2)
 
         # Filter out the previous event using the event_id we grabbed above
-        row = [row for row in res2 if row["event_id"] != event_id]
+        row = [row for row in res2 if row[1] != event_id]
 
         # Check that alice's display name is now None
-        self.assertEqual(row[0]["display_name"], None)
+        self.assertIsNone(row[0][0])
 
     def test_room_is_locally_forgotten(self) -> None:
         """Test that when the last local user has forgotten a room it is known as forgotten."""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b9446c36c..2715c73f16 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple, cast
 
 from immutabledict import immutabledict
 
@@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
         )
 
         # check that only state events are in state_groups, and all state events are in state_groups
-        res = self.get_success(
-            self.store.db_pool.simple_select_list(
-                table="state_groups",
-                keyvalues=None,
-                retcols=("event_id",),
-            )
+        res = cast(
+            List[Tuple[str]],
+            self.get_success(
+                self.store.db_pool.simple_select_list(
+                    table="state_groups",
+                    keyvalues=None,
+                    retcols=("event_id",),
+                )
+            ),
         )
 
         events = []
         for result in res:
-            self.assertNotIn(event3.event_id, result)
-            events.append(result.get("event_id"))
+            self.assertNotIn(event3.event_id, result)  # XXX
+            events.append(result[0])
 
         for event, _ in processed_events_and_context:
             if event.is_state():
@@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
         # has an entry and prev event in state_group_edges
         for event, context in processed_events_and_context:
             if event.is_state():
-                state = self.get_success(
-                    self.store.db_pool.simple_select_list(
-                        table="state_groups_state",
-                        keyvalues={"state_group": context.state_group_after_event},
-                        retcols=("type", "state_key"),
-                    )
-                )
-                self.assertEqual(event.type, state[0].get("type"))
-                self.assertEqual(event.state_key, state[0].get("state_key"))
-
-                groups = self.get_success(
-                    self.store.db_pool.simple_select_list(
-                        table="state_group_edges",
-                        keyvalues={"state_group": str(context.state_group_after_event)},
-                        retcols=("*",),
-                    )
+                state = cast(
+                    List[Tuple[str, str]],
+                    self.get_success(
+                        self.store.db_pool.simple_select_list(
+                            table="state_groups_state",
+                            keyvalues={"state_group": context.state_group_after_event},
+                            retcols=("type", "state_key"),
+                        )
+                    ),
                 )
-                self.assertEqual(
-                    context.state_group_before_event, groups[0].get("prev_state_group")
+                self.assertEqual(event.type, state[0][0])
+                self.assertEqual(event.state_key, state[0][1])
+
+                groups = cast(
+                    List[Tuple[str]],
+                    self.get_success(
+                        self.store.db_pool.simple_select_list(
+                            table="state_group_edges",
+                            keyvalues={
+                                "state_group": str(context.state_group_after_event)
+                            },
+                            retcols=("prev_state_group",),
+                        )
+                    ),
                 )
+                self.assertEqual(context.state_group_before_event, groups[0][0])
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 8c72aa1722..822c41dd9f 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, cast
 from unittest import mock
 from unittest.mock import Mock, patch
 
@@ -62,14 +62,13 @@ class GetUserDirectoryTables:
         Returns a list of tuples (user_id, room_id) where room_id is public and
         contains the user with the given id.
         """
-        r = await self.store.db_pool.simple_select_list(
-            "users_in_public_rooms", None, ("user_id", "room_id")
+        r = cast(
+            List[Tuple[str, str]],
+            await self.store.db_pool.simple_select_list(
+                "users_in_public_rooms", None, ("user_id", "room_id")
+            ),
         )
-
-        retval = set()
-        for i in r:
-            retval.add((i["user_id"], i["room_id"]))
-        return retval
+        return set(r)
 
     async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
         """Fetch the entire `users_who_share_private_rooms` table.
@@ -78,27 +77,30 @@ class GetUserDirectoryTables:
         to the rows of `users_who_share_private_rooms`.
         """
 
-        rows = await self.store.db_pool.simple_select_list(
-            "users_who_share_private_rooms",
-            None,
-            ["user_id", "other_user_id", "room_id"],
+        rows = cast(
+            List[Tuple[str, str, str]],
+            await self.store.db_pool.simple_select_list(
+                "users_who_share_private_rooms",
+                None,
+                ["user_id", "other_user_id", "room_id"],
+            ),
         )
-        rv = set()
-        for row in rows:
-            rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
-        return rv
+        return set(rows)
 
     async def get_users_in_user_directory(self) -> Set[str]:
         """Fetch the set of users in the `user_directory` table.
 
         This is useful when checking we've correctly excluded users from the directory.
         """
-        result = await self.store.db_pool.simple_select_list(
-            "user_directory",
-            None,
-            ["user_id"],
+        result = cast(
+            List[Tuple[str]],
+            await self.store.db_pool.simple_select_list(
+                "user_directory",
+                None,
+                ["user_id"],
+            ),
         )
-        return {row["user_id"] for row in result}
+        return {row[0] for row in result}
 
     async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
         """Fetch users and their profiles from the `user_directory` table.
@@ -107,16 +109,17 @@ class GetUserDirectoryTables:
         It's almost the entire contents of the `user_directory` table: the only
         thing missing is an unused room_id column.
         """
-        rows = await self.store.db_pool.simple_select_list(
-            "user_directory",
-            None,
-            ("user_id", "display_name", "avatar_url"),
+        rows = cast(
+            List[Tuple[str, Optional[str], Optional[str]]],
+            await self.store.db_pool.simple_select_list(
+                "user_directory",
+                None,
+                ("user_id", "display_name", "avatar_url"),
+            ),
         )
         return {
-            row["user_id"]: ProfileInfo(
-                display_name=row["display_name"], avatar_url=row["avatar_url"]
-            )
-            for row in rows
+            user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
+            for user_id, display_name, avatar_url in rows
         }
 
     async def get_tables(