diff options
Diffstat (limited to 'synapse/storage/user_directory.py')
-rw-r--r-- | synapse/storage/user_directory.py | 212 |
1 files changed, 112 insertions, 100 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index a8781b0e5d..fea866c043 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -22,6 +22,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -31,12 +32,19 @@ logger = logging.getLogger(__name__) class UserDirectoryStore(SQLBaseStore): - @cachedInlineCallbacks(cache_context=True) - def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context): + @defer.inlineCallbacks + def is_room_world_readable_or_publicly_joinable(self, room_id): """Check if the room is either world_readable or publically joinable """ - current_state_ids = yield self.get_current_state_ids( - room_id, on_invalidate=cache_context.invalidate + + # Create a state filter that only queries join and history state event + types_to_filter = ( + (EventTypes.JoinRules, ""), + (EventTypes.RoomHistoryVisibility, ""), + ) + + current_state_ids = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) @@ -66,14 +74,8 @@ class UserDirectoryStore(SQLBaseStore): """ yield self._simple_insert_many( table="users_in_public_rooms", - values=[ - { - "user_id": user_id, - "room_id": room_id, - } - for user_id in user_ids - ], - desc="add_users_to_public_room" + values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids], + desc="add_users_to_public_room", ) for user_id in user_ids: self.get_user_in_public_room.invalidate((user_id,)) @@ -99,7 +101,9 @@ class UserDirectoryStore(SQLBaseStore): """ args = ( ( - user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), profile.display_name, ) for user_id, profile in iteritems(users_with_profile) @@ -112,7 +116,7 @@ class UserDirectoryStore(SQLBaseStore): args = ( ( user_id, - "%s %s" % (user_id, p.display_name,) if p.display_name else user_id + "%s %s" % (user_id, p.display_name) if p.display_name else user_id, ) for user_id, p in iteritems(users_with_profile) ) @@ -133,12 +137,10 @@ class UserDirectoryStore(SQLBaseStore): "avatar_url": profile.avatar_url, } for user_id, profile in iteritems(users_with_profile) - ] + ], ) for user_id in users_with_profile: - txn.call_after( - self.get_user_in_directory.invalidate, (user_id,) - ) + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) return self.runInteraction( "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn @@ -168,39 +170,69 @@ class UserDirectoryStore(SQLBaseStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if new_entry: + if self.database_engine.can_native_upsert: sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('english', ?), 'A') || setweight(to_tsvector('english', ?), 'D') || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - ) + ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( sql, ( - user_id, get_localpart_from_id(user_id), - get_domain_from_id(user_id), display_name, - ) + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), ) else: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, user_id, + # TODO: Remove this code after we've bumped the minimum version + # of postgres to always support upserts, so we can get rid of + # `new_entry` usage + if new_entry is True: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) + elif new_entry is False: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + user_id, + ), + ) + else: + raise RuntimeError( + "upsert returned None when 'can_native_upsert' is False" ) - ) elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name,) if display_name else user_id + value = "%s %s" % (user_id, display_name) if display_name else user_id self._simple_upsert_txn( txn, table="user_directory_search", @@ -231,29 +263,18 @@ class UserDirectoryStore(SQLBaseStore): def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): self._simple_delete_txn( - txn, - table="user_directory", - keyvalues={"user_id": user_id}, + txn, table="user_directory", keyvalues={"user_id": user_id} ) self._simple_delete_txn( - txn, - table="user_directory_search", - keyvalues={"user_id": user_id}, + txn, table="user_directory_search", keyvalues={"user_id": user_id} ) self._simple_delete_txn( - txn, - table="users_in_public_rooms", - keyvalues={"user_id": user_id}, - ) - txn.call_after( - self.get_user_in_directory.invalidate, (user_id,) + txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - txn.call_after( - self.get_user_in_public_room.invalidate, (user_id,) - ) - return self.runInteraction( - "remove_from_user_dir", _remove_from_user_dir_txn, - ) + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + txn.call_after(self.get_user_in_public_room.invalidate, (user_id,)) + + return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) @defer.inlineCallbacks def remove_from_user_in_public_room(self, user_id): @@ -338,6 +359,7 @@ class UserDirectoryStore(SQLBaseStore): share_private (bool): Is the room private user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. """ + def _add_users_who_share_room_txn(txn): self._simple_insert_many_txn( txn, @@ -354,13 +376,12 @@ class UserDirectoryStore(SQLBaseStore): ) for user_id, other_user_id in user_id_tuples: txn.call_after( - self.get_users_who_share_room_from_dir.invalidate, - (user_id,), + self.get_users_who_share_room_from_dir.invalidate, (user_id,) ) txn.call_after( - self.get_if_users_share_a_room.invalidate, - (user_id, other_user_id), + self.get_if_users_share_a_room.invalidate, (user_id, other_user_id) ) + return self.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -374,6 +395,7 @@ class UserDirectoryStore(SQLBaseStore): share_private (bool): Is the room private user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. """ + def _update_users_who_share_room_txn(txn): sql = """ UPDATE users_who_share_rooms @@ -381,21 +403,16 @@ class UserDirectoryStore(SQLBaseStore): WHERE user_id = ? AND other_user_id = ? """ txn.executemany( - sql, - ( - (room_id, share_private, uid, oid) - for uid, oid in user_id_sets - ) + sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets) ) for user_id, other_user_id in user_id_sets: txn.call_after( - self.get_users_who_share_room_from_dir.invalidate, - (user_id,), + self.get_users_who_share_room_from_dir.invalidate, (user_id,) ) txn.call_after( - self.get_if_users_share_a_room.invalidate, - (user_id, other_user_id), + self.get_if_users_share_a_room.invalidate, (user_id, other_user_id) ) + return self.runInteraction( "update_users_who_share_room", _update_users_who_share_room_txn ) @@ -409,22 +426,18 @@ class UserDirectoryStore(SQLBaseStore): share_private (bool): Is the room private user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. """ + def _remove_user_who_share_room_txn(txn): self._simple_delete_txn( txn, table="users_who_share_rooms", - keyvalues={ - "user_id": user_id, - "other_user_id": other_user_id, - }, + keyvalues={"user_id": user_id, "other_user_id": other_user_id}, ) txn.call_after( - self.get_users_who_share_room_from_dir.invalidate, - (user_id,), + self.get_users_who_share_room_from_dir.invalidate, (user_id,) ) txn.call_after( - self.get_if_users_share_a_room.invalidate, - (user_id, other_user_id), + self.get_if_users_share_a_room.invalidate, (user_id, other_user_id) ) return self.runInteraction( @@ -445,10 +458,7 @@ class UserDirectoryStore(SQLBaseStore): """ return self._simple_select_one_onecol( table="users_who_share_rooms", - keyvalues={ - "user_id": user_id, - "other_user_id": other_user_id, - }, + keyvalues={"user_id": user_id, "other_user_id": other_user_id}, retcol="share_private", allow_none=True, desc="get_if_users_share_a_room", @@ -466,17 +476,12 @@ class UserDirectoryStore(SQLBaseStore): """ rows = yield self._simple_select_list( table="users_who_share_rooms", - keyvalues={ - "user_id": user_id, - }, - retcols=("other_user_id", "share_private",), + keyvalues={"user_id": user_id}, + retcols=("other_user_id", "share_private"), desc="get_users_who_share_room_with_user", ) - defer.returnValue({ - row["other_user_id"]: row["share_private"] - for row in rows - }) + defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows}) def get_users_in_share_dir_with_room_id(self, user_id, room_id): """Get all user tuples that are in the users_who_share_rooms due to the @@ -523,6 +528,7 @@ class UserDirectoryStore(SQLBaseStore): def delete_all_from_user_dir(self): """Delete the entire user directory """ + def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory_search") @@ -532,6 +538,7 @@ class UserDirectoryStore(SQLBaseStore): txn.call_after(self.get_user_in_public_room.invalidate_all) txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all) txn.call_after(self.get_if_users_share_a_room.invalidate_all) + return self.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @@ -541,7 +548,7 @@ class UserDirectoryStore(SQLBaseStore): return self._simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, - retcols=("room_id", "display_name", "avatar_url",), + retcols=("room_id", "display_name", "avatar_url"), allow_none=True, desc="get_user_in_directory", ) @@ -574,7 +581,9 @@ class UserDirectoryStore(SQLBaseStore): def get_current_state_deltas(self, prev_stream_id): prev_stream_id = int(prev_stream_id) - if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + if not self._curr_state_delta_stream_cache.has_any_entity_changed( + prev_stream_id + ): return [] def get_current_state_deltas_txn(txn): @@ -608,7 +617,7 @@ class UserDirectoryStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ - txn.execute(sql, (prev_stream_id, max_stream_id,)) + txn.execute(sql, (prev_stream_id, max_stream_id)) return self.cursor_to_dict(txn) return self.runInteraction( @@ -698,8 +707,11 @@ class UserDirectoryStore(SQLBaseStore): display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % (join_clause, where_clause) - args = join_args + (full_query, exact_query, prefix_query, limit + 1,) + """ % ( + join_clause, + where_clause, + ) + args = join_args + (full_query, exact_query, prefix_query, limit + 1) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) @@ -716,7 +728,10 @@ class UserDirectoryStore(SQLBaseStore): display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % (join_clause, where_clause) + """ % ( + join_clause, + where_clause, + ) args = join_args + (search_query, limit + 1) else: # This should be unreachable. @@ -728,10 +743,7 @@ class UserDirectoryStore(SQLBaseStore): limited = len(results) > limit - defer.returnValue({ - "limited": limited, - "results": results, - }) + defer.returnValue({"limited": limited, "results": results}) def _parse_query_sqlite(search_term): @@ -746,7 +758,7 @@ def _parse_query_sqlite(search_term): # Pull out the individual words, discarding any non-word characters. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - return " & ".join("(%s* OR %s)" % (result, result,) for result in results) + return " & ".join("(%s* OR %s)" % (result, result) for result in results) def _parse_query_postgres(search_term): @@ -759,7 +771,7 @@ def _parse_query_postgres(search_term): # Pull out the individual words, discarding any non-word characters. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - both = " & ".join("(%s:* | %s)" % (result, result,) for result in results) + both = " & ".join("(%s:* | %s)" % (result, result) for result in results) exact = " & ".join("%s" % (result,) for result in results) prefix = " & ".join("%s:*" % (result,) for result in results) |