diff options
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r-- | synapse/storage/roommember.py | 238 |
1 files changed, 162 insertions, 76 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 257bcdb2f8..eecb276465 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -108,7 +108,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): room_id, on_invalidate=cache_context.invalidate ) hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - defer.returnValue(hosts) + return hosts @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): @@ -156,9 +156,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. sql = """ SELECT count(*), membership FROM current_state_events WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL GROUP BY membership """ else: @@ -179,19 +182,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent - sql = """ - SELECT m.user_id, m.membership, m.event_id - FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? - ORDER BY - CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - m.event_id ASC - LIMIT ? - """ + if self._current_state_events_membership_up_to_date: + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT state_key, membership, event_id + FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + ORDER BY + CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + event_id ASC + LIMIT ? + """ + else: + sql = """ + SELECT c.state_key, m.membership, c.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c USING (room_id, event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + c.event_id ASC + LIMIT ? + """ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) @@ -253,31 +267,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): invites = yield self.get_invited_rooms_for_user(user_id) for invite in invites: if invite.room_id == room_id: - defer.returnValue(invite) - defer.returnValue(None) + return invite + return None + @defer.inlineCallbacks def get_rooms_for_user_where_membership_is(self, user_id, membership_list): """ Get all the rooms for this user where the membership for this user matches one in the membership list. + Filters out forgotten rooms. + Args: user_id (str): The user ID. membership_list (list): A list of synapse.api.constants.Membership values which the user must be in. + Returns: - A list of dictionary objects, with room_id, membership and sender - defined. + Deferred[list[RoomsForUser]] """ if not membership_list: return defer.succeed(None) - return self.runInteraction( + rooms = yield self.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, user_id, membership_list, ) + # Now we filter out forgotten rooms + forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + return [room for room in rooms if room.room_id not in forgotten_rooms] + def _get_rooms_for_user_where_membership_is_txn( self, txn, user_id, membership_list ): @@ -287,26 +308,33 @@ class RoomMemberWorkerStore(EventsWorkerStore): results = [] if membership_list: - where_clause = "user_id = ? AND (%s) AND forgotten = 0" % ( - " OR ".join(["m.membership = ?" for _ in membership_list]), - ) - - args = [user_id] - args.extend(membership_list) + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND c.membership IN (%s) + """ % ( + ",".join("?" * len(membership_list)) + ) + else: + sql = """ + SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND m.membership IN (%s) + """ % ( + ",".join("?" * len(membership_list)) + ) - sql = ( - "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering" - " FROM current_state_events as c" - " INNER JOIN room_memberships as m" - " ON m.event_id = c.event_id" - " INNER JOIN events as e" - " ON e.event_id = c.event_id" - " AND m.room_id = c.room_id" - " AND m.user_id = c.state_key" - " WHERE c.type = 'm.room.member' AND %s" - ) % (where_clause,) - - txn.execute(sql, args) + txn.execute(sql, (user_id, *membership_list)) results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] if do_invite: @@ -347,11 +375,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): rooms = yield self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN] ) - defer.returnValue( - frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - ) + return frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms ) @defer.inlineCallbacks @@ -361,7 +387,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): rooms = yield self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate ) - defer.returnValue(frozenset(r.room_id for r in rooms)) + return frozenset(r.room_id for r in rooms) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) def get_users_who_share_room_with_user(self, user_id, cache_context): @@ -378,7 +404,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) user_who_share_room.update(user_ids) - defer.returnValue(user_who_share_room) + return user_who_share_room @defer.inlineCallbacks def get_joined_users_from_context(self, event, context): @@ -394,7 +420,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) - defer.returnValue(result) + return result def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group @@ -508,7 +534,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): avatar_url=to_ascii(event.content.get("avatar_url", None)), ) - defer.returnValue(users_in_room) + return users_in_room @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): @@ -533,14 +559,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) if not rows: - defer.returnValue(False) + return False user_id = rows[0][0] if get_domain_from_id(user_id) != host: # This can only happen if the host name has something funky in it raise Exception("Invalid host name") - defer.returnValue(True) + return True @cachedInlineCallbacks() def was_host_joined(self, room_id, host): @@ -573,14 +599,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) if not rows: - defer.returnValue(False) + return False user_id = rows[0][0] if get_domain_from_id(user_id) != host: # This can only happen if the host name has something funky in it raise Exception("Invalid host name") - defer.returnValue(True) + return True def get_joined_hosts(self, room_id, state_entry): state_group = state_entry.state_group @@ -607,7 +633,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cache = self._get_joined_hosts_cache(room_id) joined_hosts = yield cache.get_destinations(state_entry) - defer.returnValue(joined_hosts) + return joined_hosts @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id): @@ -637,7 +663,45 @@ class RoomMemberWorkerStore(EventsWorkerStore): return rows[0][0] count = yield self.runInteraction("did_forget_membership", f) - defer.returnValue(count == 0) + return count == 0 + + @cached() + def get_forgotten_rooms_for_user(self, user_id): + """Gets all rooms the user has forgotten. + + Args: + user_id (str) + + Returns: + Deferred[set[str]] + """ + + def _get_forgotten_rooms_for_user_txn(txn): + # This is a slightly convoluted query that first looks up all rooms + # that the user has forgotten in the past, then rechecks that list + # to see if any have subsequently been updated. This is done so that + # we can use a partial index on `forgotten = 1` on the assumption + # that few users will actually forget many rooms. + # + # Note that a room is considered "forgotten" if *all* membership + # events for that user and room have the forgotten field set (as + # when a user forgets a room we update all rows for that user and + # room, not just the current one). + sql = """ + SELECT room_id, ( + SELECT count(*) FROM room_memberships + WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 + ) AS count + FROM room_memberships AS m + WHERE user_id = ? AND forgotten = 1 + GROUP BY room_id, user_id; + """ + txn.execute(sql, (user_id,)) + return set(row[0] for row in txn if row[1] == 0) + + return self.runInteraction( + "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn + ) @defer.inlineCallbacks def get_rooms_user_has_been_in(self, user_id): @@ -670,6 +734,13 @@ class RoomMemberStore(RoomMemberWorkerStore): _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) + self.register_background_index_update( + "room_membership_forgotten_idx", + index_name="room_memberships_user_room_forgotten", + table="room_memberships", + columns=["user_id", "room_id"], + where_clause="forgotten = 1", + ) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. @@ -771,6 +842,9 @@ class RoomMemberStore(RoomMemberWorkerStore): txn.execute(sql, (user_id, room_id)) self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.get_forgotten_rooms_for_user, (user_id,) + ) return self.runInteraction("forget_membership", f) @@ -847,53 +921,65 @@ class RoomMemberStore(RoomMemberWorkerStore): if not result: yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) - defer.returnValue(result) + return result @defer.inlineCallbacks def _background_current_state_membership(self, progress, batch_size): """Update the new membership column on current_state_events. + + This works by iterating over all rooms in alphebetical order. """ - if "rooms" not in progress: - rooms = yield self._simple_select_onecol( - table="current_state_events", - keyvalues={}, - retcol="DISTINCT room_id", - desc="_background_current_state_membership_get_rooms", - ) - progress["rooms"] = rooms + def _background_current_state_membership_txn(txn, last_processed_room): + processed = 0 + while processed < batch_size: + txn.execute( + """ + SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? + """, + (last_processed_room,), + ) + row = txn.fetchone() + if not row or not row[0]: + return processed, True - rooms = progress["rooms"] + next_room, = row - def _background_current_state_membership_txn(txn): - processed = 0 - while rooms and processed < batch_size: sql = """ - UPDATE current_state_events AS c + UPDATE current_state_events SET membership = ( SELECT membership FROM room_memberships - WHERE event_id = c.event_id + WHERE event_id = current_state_events.event_id ) WHERE room_id = ? """ - txn.execute(sql, (rooms.pop(),)) + txn.execute(sql, (next_room,)) processed += txn.rowcount + last_processed_room = next_room + self._background_update_progress_txn( - txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, progress + txn, + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + {"last_processed_room": last_processed_room}, ) - return processed + return processed, False - result = yield self.runInteraction( + # If we haven't got a last processed room then just use the empty + # string, which will compare before all room IDs correctly. + last_processed_room = progress.get("last_processed_room", "") + + row_count, finished = yield self.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, + last_processed_room, ) - if not rooms: + if finished: yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) - defer.returnValue(result) + return row_count class _JoinedHostsCache(object): @@ -921,7 +1007,7 @@ class _JoinedHostsCache(object): state_entry(synapse.state._StateCacheEntry) """ if state_entry.state_group == self.state_group: - defer.returnValue(frozenset(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) with (yield self.linearizer.queue(())): if state_entry.state_group == self.state_group: @@ -958,7 +1044,7 @@ class _JoinedHostsCache(object): else: self.state_group = object() self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) - defer.returnValue(frozenset(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) def __len__(self): return self._len |