diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e2cccc688c..f02c1d7ea7 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,6 @@
import logging
from typing import (
TYPE_CHECKING,
- Callable,
Collection,
Dict,
FrozenSet,
@@ -31,12 +30,8 @@ from typing import (
import attr
from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import (
- run_as_background_process,
- wrap_as_background_process,
-)
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -91,15 +86,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
- # Is the current_state_events.membership up to date? Or is the
- # background update still running?
- self._current_state_events_membership_up_to_date = False
-
- txn = db_conn.cursor(
- txn_name="_check_safe_current_state_events_membership_updated"
- )
- self._check_safe_current_state_events_membership_updated_txn(txn)
- txn.close()
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
if (
self.hs.config.worker.run_background_tasks
@@ -157,61 +144,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
- def _check_safe_current_state_events_membership_updated_txn(
- self, txn: LoggingTransaction
- ) -> None:
- """Checks if it is safe to assume the new current_state_events
- membership column is up to date
- """
-
- pending_update = self.db_pool.simple_select_one_txn(
- txn,
- table="background_updates",
- keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
- retcols=["update_name"],
- allow_none=True,
- )
-
- self._current_state_events_membership_up_to_date = not pending_update
-
- # If the update is still running, reschedule to run.
- if pending_update:
- self._clock.call_later(
- 15.0,
- run_as_background_process,
- "_check_safe_current_state_events_membership_updated",
- self.db_pool.runInteraction,
- "_check_safe_current_state_events_membership_updated",
- self._check_safe_current_state_events_membership_updated_txn,
- )
-
@cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
- return await self.db_pool.runInteraction(
- "get_users_in_room", self.get_users_in_room_txn, room_id
+ """Returns a list of users in the room.
+
+ Will return inaccurate results for rooms with partial state, since the state for
+ the forward extremities of those rooms will exclude most members. We may also
+ calculate room state incorrectly for such rooms and believe that a member is or
+ is not in the room when the opposite is true.
+
+ Note: If you only care about users in the room local to the homeserver, use
+ `get_local_users_in_room(...)` instead which will be more performant.
+ """
+ return await self.db_pool.simple_select_onecol(
+ table="current_state_events",
+ keyvalues={
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "membership": Membership.JOIN,
+ },
+ retcol="state_key",
+ desc="get_users_in_room",
)
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
- # If we can assume current_state_events.membership is up to date
- # then we can avoid a join, which is a Very Good Thing given how
- # frequently this function gets called.
- if self._current_state_events_membership_up_to_date:
- sql = """
- SELECT state_key FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
- """
- else:
- sql = """
- SELECT state_key FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
- """
+ """Returns a list of users in the room."""
- txn.execute(sql, (room_id, Membership.JOIN))
- return [r[0] for r in txn]
+ return self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="current_state_events",
+ keyvalues={
+ "type": EventTypes.Member,
+ "room_id": room_id,
+ "membership": Membership.JOIN,
+ },
+ retcol="state_key",
+ )
@cached()
def get_user_in_room_with_profile(
@@ -283,6 +251,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
A mapping from user ID to ProfileInfo.
+
+ Preconditions:
+ - There is full state available for the room (it is not partial-stated).
"""
def _get_users_in_room_with_profiles(
@@ -322,28 +293,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
- # If we can assume current_state_events.membership is up to date
- # then we can avoid a join, which is a Very Good Thing given how
- # frequently this function gets called.
- if self._current_state_events_membership_up_to_date:
- # Note, rejected events will have a null membership field, so
- # we we manually filter them out.
- sql = """
- SELECT count(*), membership FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ?
- AND membership IS NOT NULL
- GROUP BY membership
- """
- else:
- sql = """
- SELECT count(*), m.membership FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ GROUP BY membership
+ """
txn.execute(sql, (room_id,))
res: Dict[str, MemberSummary] = {}
@@ -352,30 +309,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
- if self._current_state_events_membership_up_to_date:
- # Note, rejected events will have a null membership field, so
- # we we manually filter them out.
- sql = """
- SELECT state_key, membership, event_id
- FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ?
- AND membership IS NOT NULL
- ORDER BY
- CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- event_id ASC
- LIMIT ?
- """
- else:
- sql = """
- SELECT c.state_key, m.membership, c.event_id
- FROM room_memberships as m
- INNER JOIN current_state_events as c USING (room_id, event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- ORDER BY
- CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- c.event_id ASC
- LIMIT ?
- """
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT state_key, membership, event_id
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ ORDER BY
+ CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ event_id ASC
+ LIMIT ?
+ """
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
@@ -531,6 +476,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_local_users_in_room",
)
+ async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+ """
+ Check whether a given local user is currently joined to the given room.
+
+ Returns:
+ A boolean indicating whether the user is currently joined to the room
+
+ Raises:
+ Exception when called with a user who is not local to this homeserver
+ """
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'check_local_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ (
+ membership,
+ member_event_id,
+ ) = await self.get_local_current_membership_for_user_in_room(
+ user_id=user_id,
+ room_id=room_id,
+ )
+
+ return membership == Membership.JOIN
+
+ async def is_server_notice_room(self, room_id: str) -> bool:
+ """
+ Determines whether the given room is a 'Server Notices' room, used for
+ sending server notices to a user.
+
+ This is determined by seeing whether the server notices user is present
+ in the room.
+ """
+ if self._server_notices_mxid is None:
+ return False
+ is_server_notices_room = await self.check_local_user_in_room(
+ user_id=self._server_notices_mxid, room_id=room_id
+ )
+ return is_server_notices_room
+
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
@@ -592,27 +578,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
- if self._current_state_events_membership_up_to_date:
- sql = """
- SELECT room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.state_key = ?
- AND c.membership = ?
- """
- else:
- sql = """
- SELECT room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (room_id, event_id)
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.state_key = ?
- AND m.membership = ?
- """
+ sql = """
+ SELECT room_id, e.instance_name, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND c.state_key = ?
+ AND c.membership = ?
+ """
txn.execute(sql, (user_id, Membership.JOIN))
return frozenset(
@@ -622,117 +596,124 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for room_id, instance, stream_id in txn
)
- @cachedList(
- cached_method_name="get_rooms_for_user_with_stream_ordering",
- list_name="user_ids",
- )
- async def get_rooms_for_users_with_stream_ordering(
+ async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
- ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
- """A batched version of `get_rooms_for_user_with_stream_ordering`.
-
- Returns:
- Map from user_id to set of rooms that is currently in.
+ ) -> Set[str]:
+ """Given a list of users, return the set that the server still shares a
+ room with.
"""
+
+ if not user_ids:
+ return set()
+
return await self.db_pool.runInteraction(
- "get_rooms_for_users_with_stream_ordering",
- self._get_rooms_for_users_with_stream_ordering_txn,
+ "get_users_server_still_shares_room_with",
+ self.get_users_server_still_shares_room_with_txn,
user_ids,
)
- def _get_rooms_for_users_with_stream_ordering_txn(
- self, txn: LoggingTransaction, user_ids: Collection[str]
- ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+ def get_users_server_still_shares_room_with_txn(
+ self,
+ txn: LoggingTransaction,
+ user_ids: Collection[str],
+ ) -> Set[str]:
+ if not user_ids:
+ return set()
+
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND %s
+ GROUP BY state_key
+ """
clause, args = make_in_list_sql_clause(
- self.database_engine,
- "c.state_key",
- user_ids,
+ self.database_engine, "state_key", user_ids
)
- if self._current_state_events_membership_up_to_date:
- sql = f"""
- SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.membership = ?
- AND {clause}
- """
- else:
- sql = f"""
- SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (room_id, event_id)
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND m.membership = ?
- AND {clause}
- """
+ txn.execute(sql % (clause,), args)
- txn.execute(sql, [Membership.JOIN] + args)
+ return {row[0] for row in txn}
- result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
- user_id: set() for user_id in user_ids
- }
- for user_id, room_id, instance, stream_id in txn:
- result[user_id].add(
- GetRoomsForUserWithStreamOrdering(
- room_id, PersistedEventPosition(instance, stream_id)
- )
- )
-
- return {user_id: frozenset(v) for user_id, v in result.items()}
+ @cached(max_entries=500000, iterable=True)
+ async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
+ """Returns a set of room_ids the user is currently joined to.
- async def get_users_server_still_shares_room_with(
- self, user_ids: Collection[str]
- ) -> Set[str]:
- """Given a list of users return the set that the server still share a
- room with.
+ For a remote user, only returns rooms this server is currently
+ participating in.
"""
+ rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
+ (user_id,),
+ None,
+ update_metrics=False,
+ )
+ if rooms:
+ return frozenset(r.room_id for r in rooms)
- if not user_ids:
- return set()
-
- def _get_users_server_still_shares_room_with_txn(
- txn: LoggingTransaction,
- ) -> Set[str]:
- sql = """
- SELECT state_key FROM current_state_events
- WHERE
- type = 'm.room.member'
- AND membership = 'join'
- AND %s
- GROUP BY state_key
- """
+ room_ids = await self.db_pool.simple_select_onecol(
+ table="current_state_events",
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ "state_key": user_id,
+ },
+ retcol="room_id",
+ desc="get_rooms_for_user",
+ )
- clause, args = make_in_list_sql_clause(
- self.database_engine, "state_key", user_ids
- )
+ return frozenset(room_ids)
- txn.execute(sql % (clause,), args)
+ @cachedList(
+ cached_method_name="get_rooms_for_user",
+ list_name="user_ids",
+ )
+ async def _get_rooms_for_users(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[str]]:
+ """A batched version of `get_rooms_for_user`.
- return {row[0] for row in txn}
+ Returns:
+ Map from user_id to the set of rooms that user is currently in.
+ """
- return await self.db_pool.runInteraction(
- "get_users_server_still_shares_room_with",
- _get_users_server_still_shares_room_with_txn,
+ rows = await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="state_key",
+ iterable=user_ids,
+ retcols=(
+ "state_key",
+ "room_id",
+ ),
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ desc="get_rooms_for_users",
)
- async def get_rooms_for_user(
- self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
- ) -> FrozenSet[str]:
- """Returns a set of room_ids the user is currently joined to.
+ user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
- If a remote user only returns rooms this server is currently
- participating in.
+ for row in rows:
+ user_rooms[row["state_key"]].add(row["room_id"])
+
+ return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
+
+ async def get_rooms_for_users(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[str]]:
+ """A batched wrapper around `_get_rooms_for_users`, to prevent locking
+ other calls to `get_rooms_for_user` for large user lists.
"""
- rooms = await self.get_rooms_for_user_with_stream_ordering(
- user_id, on_invalidate=on_invalidate
- )
- return frozenset(r.room_id for r in rooms)
+ all_user_rooms: Dict[str, FrozenSet[str]] = {}
+
+ # 250 users is pretty arbitrary but the data can be quite large if users
+ # are in many rooms.
+ for batch_user_ids in batch_iter(user_ids, 250):
+ all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids))
+
+ return all_user_rooms
@cached(max_entries=10000)
async def does_pair_of_users_share_a_room(
@@ -764,7 +745,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# user and the set of other users, and then checking if there is any
# overlap.
sql = f"""
- SELECT b.state_key
+ SELECT DISTINCT b.state_key
FROM (
SELECT room_id FROM current_state_events
WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
@@ -773,7 +754,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
SELECT room_id, state_key FROM current_state_events
WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
) AS b using (room_id)
- LIMIT 1
"""
txn.execute(sql, (user_id, *args))
@@ -832,144 +812,92 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_state(
- self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
- ) -> Dict[str, ProfileInfo]:
- state_group: Union[object, int] = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- assert state_group is not None
- with Measure(self._clock, "get_joined_users_from_state"):
- return await self._get_joined_users_from_context(
- room_id, state_group, state, context=state_entry
- )
+ async def get_joined_user_ids_from_state(
+ self, room_id: str, state: StateMap[str]
+ ) -> Set[str]:
+ """
+ For a given set of state IDs, get a set of user IDs in the room.
- @cached(num_args=2, iterable=True, max_entries=100000)
- async def _get_joined_users_from_context(
- self,
- room_id: str,
- state_group: Union[object, int],
- current_state_ids: StateMap[str],
- event: Optional[EventBase] = None,
- context: Optional["_StateCacheEntry"] = None,
- ) -> Dict[str, ProfileInfo]:
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- assert state_group is not None
+ This method checks the local event cache, before calling
+ `_get_user_ids_from_membership_event_ids` for any uncached events.
+ """
- users_in_room = {}
- member_event_ids = [
- e_id
- for key, e_id in current_state_ids.items()
- if key[0] == EventTypes.Member
- ]
-
- if context is not None:
- # If we have a context with a delta from a previous state group,
- # check if we also have the result from the previous group in cache.
- # If we do then we can reuse that result and simply update it with
- # any membership changes in `delta_ids`
- if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get_immediate(
- (room_id, context.prev_group), None
- )
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
- member_event_ids = [
- e_id
- for key, e_id in context.delta_ids.items()
- if key[0] == EventTypes.Member
- ]
- for etype, state_key in context.delta_ids:
- if etype == EventTypes.Member:
- users_in_room.pop(state_key, None)
-
- # We check if we have any of the member event ids in the event cache
- # before we ask the DB
-
- # We don't update the event cache hit ratio as it completely throws off
- # the hit ratio counts. After all, we don't populate the cache if we
- # miss it here
- event_map = await self._get_events_from_cache(
- member_event_ids, update_metrics=False
- )
+ with Measure(self._clock, "get_joined_user_ids_from_state"):
+ users_in_room = set()
+ member_event_ids = [
+ e_id for key, e_id in state.items() if key[0] == EventTypes.Member
+ ]
- missing_member_event_ids = []
- for event_id in member_event_ids:
- ev_entry = event_map.get(event_id)
- if ev_entry and not ev_entry.event.rejected_reason:
- if ev_entry.event.membership == Membership.JOIN:
- users_in_room[ev_entry.event.state_key] = ProfileInfo(
- display_name=ev_entry.event.content.get("displayname", None),
- avatar_url=ev_entry.event.content.get("avatar_url", None),
- )
- else:
- missing_member_event_ids.append(event_id)
+ # We check if we have any of the member event ids in the event cache
+ # before we ask the DB
- if missing_member_event_ids:
- event_to_memberships = await self._get_joined_profiles_from_event_ids(
- missing_member_event_ids
+ # We don't update the event cache hit ratio as it completely throws off
+ # the hit ratio counts. After all, we don't populate the cache if we
+ # miss it here
+ event_map = self._get_events_from_local_cache(
+ member_event_ids, update_metrics=False
)
- users_in_room.update(row for row in event_to_memberships.values() if row)
-
- if event is not None and event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- if event.event_id in member_event_ids:
- users_in_room[event.state_key] = ProfileInfo(
- display_name=event.content.get("displayname", None),
- avatar_url=event.content.get("avatar_url", None),
+
+ missing_member_event_ids = []
+ for event_id in member_event_ids:
+ ev_entry = event_map.get(event_id)
+ if ev_entry and not ev_entry.event.rejected_reason:
+ if ev_entry.event.membership == Membership.JOIN:
+ users_in_room.add(ev_entry.event.state_key)
+ else:
+ missing_member_event_ids.append(event_id)
+
+ if missing_member_event_ids:
+ event_to_memberships = (
+ await self._get_user_ids_from_membership_event_ids(
+ missing_member_event_ids
)
+ )
+ users_in_room.update(
+ user_id for user_id in event_to_memberships.values() if user_id
+ )
- return users_in_room
+ return users_in_room
- @cached(max_entries=10000)
- def _get_joined_profile_from_event_id(
+ @cached(
+ max_entries=10000,
+ # This name matches the old function that has been replaced - the cache name
+ # is kept here to maintain backwards compatibility.
+ name="_get_joined_profile_from_event_id",
+ )
+ def _get_user_id_from_membership_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
+ cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(
+ async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+ ) -> Dict[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
- event and if so return the associated user and profile info.
+ event.
Args:
event_ids: The member event IDs to lookup
Returns:
- Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id`, or None if event is not a join.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
- retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
- desc="_get_joined_profiles_from_event_ids",
+ desc="_get_user_ids_from_membership_event_ids",
)
- return {
- row["event_id"]: (
- row["user_id"],
- ProfileInfo(
- avatar_url=row["avatar_url"], display_name=row["display_name"]
- ),
- )
- for row in rows
- }
+ return {row["event_id"]: row["user_id"] for row in rows}
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1051,6 +979,72 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
+ @cached(iterable=True, max_entries=10000)
+ async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
+ """
+ Get current hosts in room based on current state.
+
+ The heuristic of sorting by servers who have been in the room the
+ longest is good because they're most likely to have anything we ask
+ about.
+
+ For SQLite the returned list is not ordered, as SQLite doesn't support
+ the appropriate SQL.
+
+ Uses `m.room.member`s in the room state at the current forward
+ extremities to determine which hosts are in the room.
+
+ Will return inaccurate results for rooms with partial state, since the
+ state for the forward extremities of those rooms will exclude most
+ members. We may also calculate room state incorrectly for such rooms and
+ believe that a host is or is not in the room when the opposite is true.
+
+ Returns:
+ Returns a list of servers sorted by longest in the room first. (aka.
+ sorted by join with the lowest depth first).
+ """
+
+ if isinstance(self.database_engine, Sqlite3Engine):
+ # If we're using SQLite then let's just always use
+ # `get_current_hosts_in_room` rather than funky SQL.
+
+ domains = await self.get_current_hosts_in_room(room_id)
+ return list(domains)
+
+ # For PostgreSQL we can use a regex to pull out the domains from the
+ # joined users in `current_state_events`.
+
+ def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
+ # Returns a list of servers currently joined in the room sorted by
+ # longest in the room first (aka. with the lowest depth). The
+ # heuristic of sorting by servers who have been in the room the
+ # longest is good because they're most likely to have anything we
+ # ask about.
+ sql = """
+ SELECT
+ /* Match the domain part of the MXID */
+ substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+ FROM current_state_events c
+ /* Get the depth of the event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ WHERE
+ /* Find any join state events in the room */
+ c.type = 'm.room.member'
+ AND c.membership = 'join'
+ AND c.room_id = ?
+ /* Group all state events from the same domain into their own buckets (groups) */
+ GROUP BY server_domain
+ /* Sorted by lowest depth first */
+ ORDER BY min(e.depth) ASC;
+ """
+ txn.execute(sql, (room_id,))
+ # `server_domain` will be `NULL` for malformed MXIDs with no colons.
+ return [d for d, in txn if d is not None]
+
+ return await self.db_pool.runInteraction(
+ "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
+ )
+
async def get_joined_hosts(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
@@ -1128,12 +1122,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
- joined_users = await self.get_joined_users_from_state(
- room_id, state, state_entry
+ joined_user_ids = await self.get_joined_user_ids_from_state(
+ room_id, state
)
cache.hosts_to_joined_users = {}
- for user_id in joined_users:
+ for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
@@ -1212,6 +1206,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
+ async def is_locally_forgotten_room(self, room_id: str) -> bool:
+ """Returns whether all local users have forgotten this room_id.
+
+ Args:
+ room_id: The room ID to query.
+
+ Returns:
+ Whether the room is forgotten.
+ """
+
+ sql = """
+ SELECT count(*) > 0 FROM local_current_membership
+ INNER JOIN room_memberships USING (room_id, event_id)
+ WHERE
+ room_id = ?
+ AND forgotten = 0;
+ """
+
+ rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+ # `count(*)` always returns an integer
+ # If any rows still exist it means someone has not forgotten this room yet
+ return not rows[0][0]
+
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
@@ -1499,6 +1517,36 @@ class RoomMemberStore(
await self.db_pool.runInteraction("forget_membership", f)
+def extract_heroes_from_room_summary(
+ details: Mapping[str, MemberSummary], me: str
+) -> List[str]:
+ """Determine the users that represent a room, from the perspective of the `me` user.
+
+ The rules which say which users we select are specified in the "Room Summary"
+ section of
+ https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3sync
+
+ Returns a list (possibly empty) of heroes' mxids.
+ """
+ empty_ms = MemberSummary([], 0)
+
+ joined_user_ids = [
+ r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
+ ]
+ invited_user_ids = [
+ r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
+ ]
+ gone_user_ids = [
+ r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
+ ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
+
+ # FIXME: order by stream ordering rather than as returned by SQL
+ if joined_user_ids or invited_user_ids:
+ return sorted(joined_user_ids + invited_user_ids)[0:5]
+ else:
+ return sorted(gone_user_ids)[0:5]
+
+
@attr.s(slots=True, auto_attribs=True)
class _JoinedHostsCache:
"""The cached data used by the `_get_joined_hosts_cache`."""
|