diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3755773faa..1ed7f2d0ef 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -275,7 +276,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
_get_users_in_room_with_profiles,
)
- @cached(max_entries=100000)
+ @cached(max_entries=100000) # type: ignore[synapse-@cached-mutable]
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
@@ -481,6 +482,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
desc="get_local_users_in_room",
)
+ async def get_local_users_related_to_room(
+ self, room_id: str
+ ) -> List[Tuple[str, str]]:
+ """
+ Retrieves a list of the current roommembers who are local to the server and their membership status.
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="local_current_membership",
+ keyvalues={"room_id": room_id},
+ retcols=("user_id", "membership"),
+ desc="get_local_users_in_room",
+ ),
+ )
+
async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
"""
Check whether a given local user is currently joined to the given room.
@@ -683,25 +700,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from user_id to set of rooms that is currently in.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="current_state_events",
- column="state_key",
- iterable=user_ids,
- retcols=(
- "state_key",
- "room_id",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="state_key",
+ iterable=user_ids,
+ retcols=(
+ "state_key",
+ "room_id",
+ ),
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ desc="get_rooms_for_users",
),
- keyvalues={
- "type": EventTypes.Member,
- "membership": Membership.JOIN,
- },
- desc="get_rooms_for_users",
)
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
- for row in rows:
- user_rooms[row["state_key"]].add(row["room_id"])
+ for state_key, room_id in rows:
+ user_rooms[state_key].add(room_id)
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@@ -892,17 +912,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from event ID to `user_id`, or None if event is not a join.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=event_ids,
- retcols=("user_id", "event_id"),
- keyvalues={"membership": Membership.JOIN},
- batch_size=1000,
- desc="_get_user_ids_from_membership_event_ids",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("event_id", "user_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=1000,
+ desc="_get_user_ids_from_membership_event_ids",
+ ),
)
- return {row["event_id"]: row["user_id"] for row in rows}
+ return dict(rows)
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -933,7 +956,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, membership, room_id, like_clause
+ "is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@@ -1063,15 +1086,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
- rows = await self.db_pool.simple_select_list(
- "current_state_events",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "membership"),
- desc="has_completed_background_updates",
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ ),
)
- return {row["event_id"]: row["membership"] for row in rows}
+ return dict(rows)
- @cached(max_entries=10000)
+ # TODO This returns a mutable object, which is generally confusing when using a cache.
+ @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache()
@@ -1157,7 +1184,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
- rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+ rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet
@@ -1201,21 +1228,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
membership event, otherwise the value is None.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids,
- retcols=("user_id", "membership", "event_id"),
- keyvalues={},
- batch_size=500,
- desc="get_membership_from_event_ids",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ ),
)
return {
- row["event_id"]: EventIdMembership(
- membership=row["membership"], user_id=row["user_id"]
- )
- for row in rows
+ event_id: EventIdMembership(membership=membership, user_id=user_id)
+ for user_id, membership, event_id in rows
}
async def is_local_host_in_room_ignoring_users(
@@ -1348,18 +1376,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
+ for _, event_id, room_id, json in rows:
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
|