summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py62
1 files changed, 38 insertions, 24 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5eaaff5b68..598025dd91 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -20,10 +20,12 @@ from typing import (
     Collection,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
 import attr
@@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         Raises:
              RuntimeError if the state is unknown at any of the given events
         """
-        rows = await self.db_pool.simple_select_many_batch(
-            table="event_to_state_groups",
-            column="event_id",
-            iterable=event_ids,
-            keyvalues={},
-            retcols=("event_id", "state_group"),
-            desc="_get_state_group_for_events",
+        rows = cast(
+            List[Tuple[str, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="event_to_state_groups",
+                column="event_id",
+                iterable=event_ids,
+                keyvalues={},
+                retcols=("event_id", "state_group"),
+                desc="_get_state_group_for_events",
+            ),
         )
 
-        res = {row["event_id"]: row["state_group"] for row in rows}
+        res = dict(rows)
         for e in event_ids:
             if e not in res:
                 raise RuntimeError("No state group for unknown or outlier event %s" % e)
@@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             The subset of state groups that are referenced.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="event_to_state_groups",
-            column="state_group",
-            iterable=state_groups,
-            keyvalues={},
-            retcols=("DISTINCT state_group",),
-            desc="get_referenced_state_groups",
+        rows = cast(
+            List[Tuple[int]],
+            await self.db_pool.simple_select_many_batch(
+                table="event_to_state_groups",
+                column="state_group",
+                iterable=state_groups,
+                keyvalues={},
+                retcols=("DISTINCT state_group",),
+                desc="get_referenced_state_groups",
+            ),
         )
 
-        return {row["state_group"] for row in rows}
+        return {row[0] for row in rows}
 
     async def update_state_for_partial_state_event(
         self,
@@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             # potentially stale, since there may have been a period where the
             # server didn't share a room with the remote user and therefore may
             # have missed any device updates.
-            rows = self.db_pool.simple_select_many_txn(
-                txn,
-                table="current_state_events",
-                column="room_id",
-                iterable=to_delete,
-                keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
-                retcols=("state_key",),
+            rows = cast(
+                List[Tuple[str]],
+                self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="current_state_events",
+                    column="room_id",
+                    iterable=to_delete,
+                    keyvalues={
+                        "type": EventTypes.Member,
+                        "membership": Membership.JOIN,
+                    },
+                    retcols=("state_key",),
+                ),
             )
 
-            potentially_left_users = {row["state_key"] for row in rows}
+            potentially_left_users = {row[0] for row in rows}
 
             # Now lets actually delete the rooms from the DB.
             self.db_pool.simple_delete_many_txn(