summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- synapse/storage/databases/main/roommember.py 534
1 files changed, 505 insertions, 29 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py

index 1d9f0f52e1..7ca73abb83 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -19,6 +19,7 @@ # # import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, AbstractSet, @@ -39,6 +40,8 @@ from typing import ( import attr from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.logging.opentracing import trace from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -50,13 +53,20 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.stream import _filter_results_by_stream from synapse.storage.engines import Sqlite3Engine -from synapse.storage.roommember import MemberSummary, ProfileInfo, RoomsForUser +from synapse.storage.roommember import ( + MemberSummary, + ProfileInfo, + RoomsForUser, + RoomsForUserSlidingSync, +) from synapse.types import ( JsonDict, PersistedEventPosition, StateMap, StrCollection, + StreamToken, get_domain_from_id, ) from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -71,6 +81,7 @@ logger = logging.getLogger(__name__) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +_POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000 @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -225,9 +236,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND m.room_id = c.room_id AND m.user_id = c.state_key WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? 
AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (room_id, Membership.JOIN, *ids)) return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn} @@ -306,18 +315,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ + counts = self._get_member_counts_txn(txn, room_id) - txn.execute(sql, (room_id,)) res: Dict[str, MemberSummary] = {} - for count, membership in txn: + for membership, count in counts.items(): res.setdefault(membership, MemberSummary([], count)) # Order by membership (joins -> invites -> leave (former insiders) -> @@ -364,6 +365,31 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) @cached() + async def get_member_counts(self, room_id: str) -> Mapping[str, int]: + """Get a mapping of number of users by membership""" + + return await self.db_pool.runInteraction( + "get_member_counts", self._get_member_counts_txn, room_id + ) + + def _get_member_counts_txn( + self, txn: LoggingTransaction, room_id: str + ) -> Dict[str, int]: + """Get a mapping of number of users by membership""" + + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? 
+ AND membership IS NOT NULL + GROUP BY membership + """ + + txn.execute(sql, (room_id,)) + return {membership: count for count, membership in txn} + + @cached() async def get_number_joined_users_in_room(self, room_id: str) -> int: return await self.db_pool.simple_select_one_onecol( table="current_state_events", @@ -524,9 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): WHERE user_id = ? AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (user_id, *args)) results = [ @@ -631,10 +655,8 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ # Paranoia check. if not self.hs.is_mine_id(user_id): - raise Exception( - "Cannot call 'get_local_current_membership_for_user_in_room' on " - "non-local user %s" % (user_id,), - ) + message = f"Provided user_id {user_id} is a non-local user" + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) results = cast( Optional[Tuple[str, str]], @@ -692,6 +714,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {row[0] for row in txn} + async def get_rooms_user_currently_banned_from( + self, user_id: str + ) -> FrozenSet[str]: + """Returns a set of room_ids the user is currently banned from. + + If a remote user only returns rooms this server is currently + participating in. + """ + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.BAN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_user_currently_banned_from", + ) + + return frozenset(room_ids) + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. 
@@ -808,7 +851,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (user_id, *args)) - return {u: True for u, in txn} + return {u: True for (u,) in txn} to_return = {} for batch_user_ids in batch_iter(other_user_ids, 1000): @@ -828,6 +871,73 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {u for u, share_room in user_dict.items() if share_room} + @cached(max_entries=10000) + async def does_pair_of_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_id: str + ) -> bool: + raise NotImplementedError() + + @cachedList( + cached_method_name="does_pair_of_users_share_a_room_joined_or_invited", + list_name="other_user_ids", + ) + async def _do_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_ids: Collection[str] + ) -> Mapping[str, Optional[bool]]: + """Return mapping from user ID to whether they share a room with the + given user via being either joined or invited. + + Note: `None` and `False` are equivalent and mean they don't share a + room. + """ + + def do_users_share_a_room_joined_or_invited_txn( + txn: LoggingTransaction, user_ids: Collection[str] + ) -> Dict[str, bool]: + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids + ) + + # This query works by fetching both the list of rooms for the target + # user and the set of other users, and then checking if there is any + # overlap. + sql = f""" + SELECT DISTINCT b.state_key + FROM ( + SELECT room_id FROM current_state_events + WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND state_key = ? 
+ ) AS a + INNER JOIN ( + SELECT room_id, state_key FROM current_state_events + WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND {clause} + ) AS b using (room_id) + """ + + txn.execute(sql, (user_id, *args)) + return {u: True for (u,) in txn} + + to_return = {} + for batch_user_ids in batch_iter(other_user_ids, 1000): + res = await self.db_pool.runInteraction( + "do_users_share_a_room_joined_or_invited", + do_users_share_a_room_joined_or_invited_txn, + batch_user_ids, + ) + to_return.update(res) + + return to_return + + async def do_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_ids: Collection[str] + ) -> Set[str]: + """Return the set of users who share a room with the first users via being either joined or invited""" + + user_dict = await self._do_users_share_a_room_joined_or_invited( + user_id, other_user_ids + ) + + return {u for u, share_room in user_dict.items() if share_room} + async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]: """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) @@ -1026,7 +1136,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND room_id = ? """ txn.execute(sql, (room_id,)) - return {d for d, in txn} + return {d for (d,) in txn} return await self.db_pool.runInteraction( "get_current_hosts_in_room", get_current_hosts_in_room_txn @@ -1094,7 +1204,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return tuple(d for d, in txn if d is not None) + return tuple(d for (d,) in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn @@ -1311,9 +1421,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): room_id = ? 
AND membership = ? AND NOT (%s) LIMIT 1 - """ % ( - clause, - ) + """ % (clause,) def _is_local_host_in_room_ignoring_users_txn( txn: LoggingTransaction, @@ -1337,11 +1445,23 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): keyvalues={"user_id": user_id, "room_id": room_id}, updatevalues={"forgotten": 1}, ) + # Handle updating the `sliding_sync_membership_snapshots` table + self.db_pool.simple_update_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={"forgotten": 1}, + ) self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) self._invalidate_cache_and_stream( txn, self.get_forgotten_rooms_for_user, (user_id,) ) + self._invalidate_cache_and_stream( + txn, + self.get_sliding_sync_rooms_for_user_from_membership_snapshots, + (user_id,), + ) await self.db_pool.runInteraction("forget_membership", f) @@ -1371,6 +1491,360 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): desc="room_forgetter_stream_pos", ) + @cached(iterable=True, max_entries=10000) + async def get_sliding_sync_rooms_for_user_from_membership_snapshots( + self, user_id: str + ) -> Mapping[str, RoomsForUserSlidingSync]: + """ + Get all the rooms for a user to handle a sliding sync request from the + `sliding_sync_membership_snapshots` table. These will be current memberships and + need to be rewound to the token range. + + Ignores forgotten rooms and rooms that the user has left themselves. + + Args: + user_id: The user ID to get the rooms for. + + Returns: + Map from room ID to membership info + """ + + def _txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + # XXX: If you use any new columns that can change (like from + # `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` cache in the + # appropriate places (and add tests). 
+ sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE user_id = ? + AND m.forgotten = 0 + AND (m.membership != 'leave' OR m.user_id != m.sender) + """ + txn.execute(sql, (user_id,)) + + return { + row[0]: RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=bool(row[9]), + ) + for row in txn + # We filter out unknown room versions proactively. They + # shouldn't go down sync and their metadata may be in a broken + # state (causing errors). + if row[4] in KNOWN_ROOM_VERSIONS + } + + return await self.db_pool.runInteraction( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", + _txn, + ) + + async def get_sliding_sync_self_leave_rooms_after_to_token( + self, + user_id: str, + to_token: StreamToken, + ) -> Dict[str, RoomsForUserSlidingSync]: + """ + Get all the self-leave rooms for a user after the `to_token` (outside the token + range) that are potentially relevant[1] and needed to handle a sliding sync + request. The results are from the `sliding_sync_membership_snapshots` table and + will be current memberships and need to be rewound to the token range. + + [1] If a leave happens after the token range, we may have still been joined (or + any non-self-leave which is relevant to sync) to the room before so we need to + include it in the list of potentially relevant rooms and apply + our rewind logic (outside of this function) to see if it's actually relevant. 
+ + This is basically a sister-function to + `get_sliding_sync_rooms_for_user_from_membership_snapshots`. We could + alternatively incorporate this logic into + `get_sliding_sync_rooms_for_user_from_membership_snapshots` but those results + are cached and the `to_token` isn't very cache friendly (people are constantly + requesting with new tokens) so we separate it out here. + + Args: + user_id: The user ID to get the rooms for. + to_token: Any self-leave memberships after this position will be returned. + + Returns: + Map from room ID to membership info + """ + # TODO: Potential to check + # `self._membership_stream_cache.has_entity_changed(...)` as an early-return + # shortcut. + + def _txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + m.room_type, + m.is_encrypted + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + WHERE user_id = ? + AND m.forgotten = 0 + AND m.membership = 'leave' + AND m.user_id = m.sender + AND (m.event_stream_ordering > ?) + """ + # If a leave happens after the token range, we may have still been joined + # (or any non-self-leave which is relevant to sync) to the room before so we + # need to include it in the list of potentially relevant rooms and apply our + # rewind logic (outside of this function). 
+ # + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_to_token_position = to_token.room_key.stream + txn.execute(sql, (user_id, min_to_token_position)) + + # Map from room_id to membership info + room_membership_for_user_map: Dict[str, RoomsForUserSlidingSync] = {} + for row in txn: + room_for_user = RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=bool(row[9]), + ) + + # We filter out unknown room versions proactively. They shouldn't go + # down sync and their metadata may be in a broken state (causing + # errors). + if row[4] not in KNOWN_ROOM_VERSIONS: + continue + + # We only want to include the self-leave membership if it happened after + # the token range. + # + # Since the database pulls out more than necessary, we need to filter it + # down here. 
+ if _filter_results_by_stream( + lower_token=None, + upper_token=to_token.room_key, + instance_name=room_for_user.event_pos.instance_name, + stream_ordering=room_for_user.event_pos.stream, + ): + continue + + room_membership_for_user_map[room_for_user.room_id] = room_for_user + + return room_membership_for_user_map + + return await self.db_pool.runInteraction( + "get_sliding_sync_self_leave_rooms_after_to_token", + _txn, + ) + + async def get_sliding_sync_room_for_user( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUserSlidingSync]: + """Get the sliding sync room entry for the given user and room.""" + + def get_sliding_sync_room_for_user_txn( + txn: LoggingTransaction, + ) -> Optional[RoomsForUserSlidingSync]: + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE user_id = ? + AND m.forgotten = 0 + AND m.room_id = ? 
+ """ + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + if not row: + return None + + return RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=row[9], + ) + + return await self.db_pool.runInteraction( + "get_sliding_sync_room_for_user", get_sliding_sync_room_for_user_txn + ) + + async def get_sliding_sync_room_for_user_batch( + self, user_id: str, room_ids: StrCollection + ) -> Dict[str, RoomsForUserSlidingSync]: + """Get the sliding sync room entry for the given user and rooms.""" + + if not room_ids: + return {} + + def get_sliding_sync_room_for_user_batch_txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.room_id", room_ids + ) + sql = f""" + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE m.forgotten = 0 + AND {clause} + AND user_id = ? 
+ """ + args.append(user_id) + txn.execute(sql, args) + + return { + row[0]: RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=row[9], + ) + for row in txn + } + + return await self.db_pool.runInteraction( + "get_sliding_sync_room_for_user_batch", + get_sliding_sync_room_for_user_batch_txn, + ) + + async def get_rooms_for_user_by_date( + self, user_id: str, from_ts: int + ) -> FrozenSet[str]: + """ + Fetch a list of rooms that the user has joined at or after the given timestamp, including + those they subsequently have left/been banned from. + + Args: + user_id: user ID of the user to search for + from_ts: a timestamp in ms from the unix epoch at which to begin the search at + """ + + def _get_rooms_for_user_by_join_date_txn( + txn: LoggingTransaction, user_id: str, timestamp: int + ) -> frozenset: + sql = """ + SELECT rm.room_id + FROM room_memberships AS rm + INNER JOIN events AS e USING (event_id) + WHERE rm.user_id = ? + AND rm.membership = 'join' + AND e.type = 'm.room.member' + AND e.received_ts >= ? + """ + txn.execute(sql, (user_id, timestamp)) + return frozenset([r[0] for r in txn]) + + return await self.db_pool.runInteraction( + "_get_rooms_for_user_by_join_date_txn", + _get_rooms_for_user_by_join_date_txn, + user_id, + from_ts, + ) + + async def set_room_participation(self, user_id: str, room_id: str) -> None: + """ + Record the provided user as participating in the given room + + Args: + user_id: the user ID of the user + room_id: ID of the room to set the participant in + """ + + def _set_room_participation_txn( + txn: LoggingTransaction, user_id: str, room_id: str + ) -> None: + sql = """ + UPDATE room_memberships + SET participant = true + WHERE event_id IN ( + SELECT event_id FROM local_current_membership + WHERE user_id = ? AND room_id = ? 
+ ) + AND NOT participant + """ + txn.execute(sql, (user_id, room_id)) + + await self.db_pool.runInteraction( + "_set_room_participation_txn", _set_room_participation_txn, user_id, room_id + ) + + async def get_room_participation(self, user_id: str, room_id: str) -> bool: + """ + Check whether a user is listed as a participant in a room + + Args: + user_id: user ID of the user + room_id: ID of the room to check in + """ + + def _get_room_participation_txn( + txn: LoggingTransaction, user_id: str, room_id: str + ) -> bool: + sql = """ + SELECT participant + FROM local_current_membership AS l + INNER JOIN room_memberships AS r USING (event_id) + WHERE l.user_id = ? + AND l.room_id = ? + """ + txn.execute(sql, (user_id, room_id)) + res = txn.fetchone() + if res: + return res[0] + return False + + return await self.db_pool.runInteraction( + "_get_room_participation_txn", _get_room_participation_txn, user_id, room_id + ) + class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -1405,10 +1879,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): self, progress: JsonDict, batch_size: int ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] + "target_min_stream_id_inclusive", + self._min_stream_order_on_start, # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] + "max_stream_id_exclusive", + self._stream_order_on_start + 1, # type: ignore[attr-defined] ) def add_membership_profile_txn(txn: LoggingTransaction) -> int: