diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f4c4661aaf..36fcab06b5 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple, cast
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import Membership
@@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- res = self.get_success(
- self.store.db_pool.simple_select_list(
- "room_memberships",
- {"user_id": "@alice:test"},
- ["display_name", "event_id"],
- )
+ res = cast(
+ List[Tuple[Optional[str], str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ ),
)
# Check that we only got one result back
self.assertEqual(len(res), 1)
# Check that alice's display name is "alice"
- self.assertEqual(res[0]["display_name"], "alice")
+ self.assertEqual(res[0][0], "alice")
# Grab the event_id to use later
- event_id = res[0]["event_id"]
+ event_id = res[0][1]
# Create a profile with the offending null byte in the display name
new_profile = {"displayname": "ali\u0000ce"}
@@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
tok=self.t_alice,
)
- res2 = self.get_success(
- self.store.db_pool.simple_select_list(
- "room_memberships",
- {"user_id": "@alice:test"},
- ["display_name", "event_id"],
- )
+ res2 = cast(
+ List[Tuple[Optional[str], str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ ),
)
# Check that we only have two results
self.assertEqual(len(res2), 2)
# Filter out the previous event using the event_id we grabbed above
- row = [row for row in res2 if row["event_id"] != event_id]
+ row = [row for row in res2 if row[1] != event_id]
# Check that alice's display name is now None
- self.assertEqual(row[0]["display_name"], None)
+ self.assertIsNone(row[0][0])
def test_room_is_locally_forgotten(self) -> None:
"""Test that when the last local user has forgotten a room it is known as forgotten."""
|