diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_stats.py | 14 | ||||
-rw-r--r-- | tests/storage/databases/main/test_receipts.py | 20 | ||||
-rw-r--r-- | tests/storage/test__base.py | 16 | ||||
-rw-r--r-- | tests/storage/test_background_update.py | 35 | ||||
-rw-r--r-- | tests/storage/test_base.py | 4 | ||||
-rw-r--r-- | tests/storage/test_client_ips.py | 250 | ||||
-rw-r--r-- | tests/storage/test_roommember.py | 40 | ||||
-rw-r--r-- | tests/storage/test_state.py | 62 | ||||
-rw-r--r-- | tests/storage/test_user_directory.py | 61 |
9 files changed, 262 insertions, 240 deletions
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index d11ded6c5b..76c56d5434 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, cast from twisted.test.proto_helpers import MemoryReactor @@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - async def get_all_room_state(self) -> List[Dict[str, Any]]: - return await self.store.db_pool.simple_select_list( - "room_stats_state", None, retcols=("name", "topic", "canonical_alias") + async def get_all_room_state(self) -> List[Optional[str]]: + rows = cast( + List[Tuple[Optional[str]]], + await self.store.db_pool.simple_select_list( + "room_stats_state", None, retcols=("topic",) + ), ) + return [r[0] for r in rows] def _get_current_stats( self, stats_type: str, stat_id: str @@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): r = self.get_success(self.get_all_room_state()) self.assertEqual(len(r), 1) - self.assertEqual(r[0]["topic"], "foo") + self.assertEqual(r[0], "foo") def test_create_user(self) -> None: """ diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index 71db47405e..98b01086bc 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): if expected_row is not None: columns += expected_row.keys() - rows = self.get_success( + row_tuples = self.get_success( self.store.db_pool.simple_select_list( table=table, keyvalues={ @@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): if expected_row is not None: self.assertEqual( - len(rows), + len(row_tuples), 1, f"Background update did not leave behind latest receipt in {table}", ) self.assertEqual( - rows[0], - { - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - **expected_row, - }, + row_tuples[0], + ( + room_id, + receipt_type, + user_id, + *expected_row.values(), + ), ) else: self.assertEqual( - len(rows), + len(row_tuples), 0, f"Background update did not remove all duplicate receipts from {table}", ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8bbf936ae9..8cbc974ac4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,7 +14,7 @@ # limitations under the License. import secrets -from typing import Generator, Tuple +from typing import Generator, List, Tuple, cast from twisted.test.proto_helpers import MemoryReactor @@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase): ) def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]: - res = self.get_success( - self.storage.db_pool.simple_select_list( - self.table_name, None, ["id, username, value"] - ) + yield from cast( + List[Tuple[int, str, str]], + self.get_success( + self.storage.db_pool.simple_select_list( + self.table_name, None, ["id, username, value"] + ) + ), ) - for i in res: - yield (i["id"], i["username"], i["value"]) - def test_upsert_many(self) -> None: """ Upsert_many will perform the upsert operation across a batch of data. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index abf7d0564d..3f5bfa09d4 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple, cast from unittest.mock import AsyncMock, Mock import yaml @@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): self.wait_for_background_updates() # Check the correct values are in the new table. - rows = self.get_success( - self.store.db_pool.simple_select_list( - table="test_constraint", - keyvalues={}, - retcols=("a", "b"), - ) + rows = cast( + List[Tuple[int, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ), ) - self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + self.assertCountEqual(rows, [(1, 1), (3, 3)]) # And check that invalid rows get correctly rejected. self.get_failure( @@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): self.wait_for_background_updates() # Check the correct values are in the new table. - rows = self.get_success( - self.store.db_pool.simple_select_list( - table="test_constraint", - keyvalues={}, - retcols=("a", "b"), - ) + rows = cast( + List[Tuple[int, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ), ) - self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + self.assertCountEqual(rows, [(1, 1), (3, 3)]) # And check that invalid rows get correctly rejected. self.get_failure( diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 256d28e4c9..e4a52c301e 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: self.mock_txn.rowcount = 3 - self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) + self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)] self.mock_txn.description = (("colA", None, None, None, None, None, None),) ret = yield defer.ensureDeferred( @@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) + self.assertEqual([(1,), (2,), (3,)], ret) self.mock_txn.execute.assert_called_with( "SELECT colA FROM tablename WHERE keycol = ?", ["A set"] ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 0c054a598f..8e4393d843 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple, cast from unittest.mock import AsyncMock from parameterized import parameterized @@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) self.pump(0) - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual( - result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": None, - "last_seen": 12345678000, - } - ], + result, [("access_token", "ip", "user_agent", None, 12345678000)] ) # Add another & trigger the storage loop @@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) self.pump(0) - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) # Only one result, has been upserted. self.assertEqual( - result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": None, - "last_seen": 12345878000, - } - ], + result, [("access_token", "ip", "user_agent", None, 12345878000)] ) @parameterized.expand([(False,), (True,)]) @@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) else: # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="devices", - keyvalues={}, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + db_result = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=( + "user_id", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + ), ), ) - self.assertEqual( - db_result, - [ - { - "user_id": user_id, - "device_id": device_id, - "ip": None, - "user_agent": None, - "last_seen": None, - }, - ], - ) + self.assertEqual(db_result, [(user_id, None, None, device_id, None)]) result = self.get_success( self.store.get_last_client_ip_by_device(user_id, device_id) @@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="devices", - keyvalues={}, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + db_result = cast( + List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]], + self.get_success( + self.store.db_pool.simple_select_list( + table="devices", + keyvalues={}, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ), ) self.assertCountEqual( db_result, [ - { - "user_id": user_id, - "device_id": device_id_1, - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - { - "user_id": user_id, - "device_id": device_id_2, - "ip": "ip_2", - "user_agent": "user_agent_2", - "last_seen": 12345678000, - }, + (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000), + (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000), ], ) @@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # Check that the new IP and user agent has not been stored yet - db_result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={}, - retcols=("access_token", "ip", "user_agent", "last_seen"), + db_result = cast( + List[Tuple[str, str, str, int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=("access_token", "ip", "user_agent", "last_seen"), + ), ), ) self.assertEqual( db_result, [ - { - "access_token": "access_token", - "ip": "ip_1", - "user_agent": "user_agent_1", - "last_seen": 12345678000, - }, - { - "access_token": "access_token", - "ip": "ip_2", - "user_agent": "user_agent_2", - "last_seen": 12345678000, - }, + ("access_token", "ip_1", "user_agent_1", 12345678000), + ("access_token", "ip_2", "user_agent_2", 12345678000), ], ) @@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) # We should see that in the DB - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual( result, - [ - { - "access_token": "access_token", - "ip": "ip", - "user_agent": "user_agent", - "device_id": device_id, - "last_seen": 0, - } - ], + [("access_token", "ip", "user_agent", device_id, 0)], ) # Now advance by a couple of months self.reactor.advance(60 * 24 * 60 * 60) # We should get no results. - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={"user_id": user_id}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) self.assertEqual(result, []) @@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(200) # We should see that in the DB - result = self.get_success( - self.store.db_pool.simple_select_list( - table="user_ips", - keyvalues={}, - retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], - desc="get_user_ip_and_agents", - ) + result = cast( + List[Tuple[str, str, str, Optional[str], int]], + self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=[ + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ], + desc="get_user_ip_and_agents", + ) + ), ) # ensure user1 is filtered out - self.assertEqual( - result, - [ - { - "access_token": access_token2, - "ip": "ip", - "user_agent": "user_agent", - "device_id": device_id2, - "last_seen": 0, - } - ], - ) + self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)]) class ClientIpAuthTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index f4c4661aaf..36fcab06b5 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple, cast + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import Membership @@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - res = self.get_success( - self.store.db_pool.simple_select_list( - "room_memberships", - {"user_id": "@alice:test"}, - ["display_name", "event_id"], - ) + res = cast( + List[Tuple[Optional[str], str]], + self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ), ) # Check that we only got one result back self.assertEqual(len(res), 1) # Check that alice's display name is "alice" - self.assertEqual(res[0]["display_name"], "alice") + self.assertEqual(res[0][0], "alice") # Grab the event_id to use later - event_id = res[0]["event_id"] + event_id = res[0][1] # Create a profile with the offending null byte in the display name new_profile = {"displayname": "ali\u0000ce"} @@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): tok=self.t_alice, ) - res2 = self.get_success( - self.store.db_pool.simple_select_list( - "room_memberships", - {"user_id": "@alice:test"}, - ["display_name", "event_id"], - ) + res2 = cast( + List[Tuple[Optional[str], str]], + self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ), ) # Check that we only have two results self.assertEqual(len(res2), 2) # Filter out the previous event using the event_id we grabbed above - row = [row for row in res2 if row["event_id"] != event_id] + row = [row for row in res2 if row[1] != event_id] # Check that alice's display name is now None - self.assertEqual(row[0]["display_name"], None) + self.assertIsNone(row[0][0]) def test_room_is_locally_forgotten(self) -> None: """Test that when the last local user has forgotten a room it is known as forgotten.""" diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b9446c36c..2715c73f16 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import List, Tuple, cast from immutabledict import immutabledict @@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase): ) # check that only state events are in state_groups, and all state events are in state_groups - res = self.get_success( - self.store.db_pool.simple_select_list( - table="state_groups", - keyvalues=None, - retcols=("event_id",), - ) + res = cast( + List[Tuple[str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ), ) events = [] for result in res: - self.assertNotIn(event3.event_id, result) - events.append(result.get("event_id")) + self.assertNotIn(event3.event_id, result) # XXX + events.append(result[0]) for event, _ in processed_events_and_context: if event.is_state(): @@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase): # has an entry and prev event in state_group_edges for event, context in processed_events_and_context: if event.is_state(): - state = self.get_success( - self.store.db_pool.simple_select_list( - table="state_groups_state", - keyvalues={"state_group": context.state_group_after_event}, - retcols=("type", "state_key"), - ) - ) - self.assertEqual(event.type, state[0].get("type")) - self.assertEqual(event.state_key, state[0].get("state_key")) - - groups = self.get_success( - self.store.db_pool.simple_select_list( - table="state_group_edges", - keyvalues={"state_group": str(context.state_group_after_event)}, - retcols=("*",), - ) + state = cast( + List[Tuple[str, str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ), ) - self.assertEqual( - context.state_group_before_event, groups[0].get("prev_state_group") + self.assertEqual(event.type, state[0][0]) + self.assertEqual(event.state_key, state[0][1]) + + groups = cast( + List[Tuple[str]], + self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={ + "state_group": str(context.state_group_after_event) + }, + retcols=("prev_state_group",), + ) + ), ) + self.assertEqual(context.state_group_before_event, groups[0][0]) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 8c72aa1722..822c41dd9f 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Any, Dict, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, cast from unittest import mock from unittest.mock import Mock, patch @@ -62,14 +62,13 @@ class GetUserDirectoryTables: Returns a list of tuples (user_id, room_id) where room_id is public and contains the user with the given id. """ - r = await self.store.db_pool.simple_select_list( - "users_in_public_rooms", None, ("user_id", "room_id") + r = cast( + List[Tuple[str, str]], + await self.store.db_pool.simple_select_list( + "users_in_public_rooms", None, ("user_id", "room_id") + ), ) - - retval = set() - for i in r: - retval.add((i["user_id"], i["room_id"])) - return retval + return set(r) async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: """Fetch the entire `users_who_share_private_rooms` table. @@ -78,27 +77,30 @@ class GetUserDirectoryTables: to the rows of `users_who_share_private_rooms`. """ - rows = await self.store.db_pool.simple_select_list( - "users_who_share_private_rooms", - None, - ["user_id", "other_user_id", "room_id"], + rows = cast( + List[Tuple[str, str, str]], + await self.store.db_pool.simple_select_list( + "users_who_share_private_rooms", + None, + ["user_id", "other_user_id", "room_id"], + ), ) - rv = set() - for row in rows: - rv.add((row["user_id"], row["other_user_id"], row["room_id"])) - return rv + return set(rows) async def get_users_in_user_directory(self) -> Set[str]: """Fetch the set of users in the `user_directory` table. This is useful when checking we've correctly excluded users from the directory. """ - result = await self.store.db_pool.simple_select_list( - "user_directory", - None, - ["user_id"], + result = cast( + List[Tuple[str]], + await self.store.db_pool.simple_select_list( + "user_directory", + None, + ["user_id"], + ), ) - return {row["user_id"] for row in result} + return {row[0] for row in result} async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]: """Fetch users and their profiles from the `user_directory` table. @@ -107,16 +109,17 @@ class GetUserDirectoryTables: It's almost the entire contents of the `user_directory` table: the only thing missing is an unused room_id column. """ - rows = await self.store.db_pool.simple_select_list( - "user_directory", - None, - ("user_id", "display_name", "avatar_url"), + rows = cast( + List[Tuple[str, Optional[str], Optional[str]]], + await self.store.db_pool.simple_select_list( + "user_directory", + None, + ("user_id", "display_name", "avatar_url"), + ), ) return { - row["user_id"]: ProfileInfo( - display_name=row["display_name"], avatar_url=row["avatar_url"] - ) - for row in rows + user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url) + for user_id, display_name, avatar_url in rows } async def get_tables( |