summary refs log tree commit diff
path: root/synapse/storage/user_directory.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/user_directory.py')
-rw-r--r--synapse/storage/user_directory.py212
1 files changed, 112 insertions, 100 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index a8781b0e5d..fea866c043 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -31,12 +32,19 @@ logger = logging.getLogger(__name__)
 
 
 class UserDirectoryStore(SQLBaseStore):
-    @cachedInlineCallbacks(cache_context=True)
-    def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+    @defer.inlineCallbacks
+    def is_room_world_readable_or_publicly_joinable(self, room_id):
         """Check if the room is either world_readable or publically joinable
         """
-        current_state_ids = yield self.get_current_state_ids(
-            room_id, on_invalidate=cache_context.invalidate
+
+        # Create a state filter that only queries join and history state event
+        types_to_filter = (
+            (EventTypes.JoinRules, ""),
+            (EventTypes.RoomHistoryVisibility, ""),
+        )
+
+        current_state_ids = yield self.get_filtered_current_state_ids(
+            room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
@@ -66,14 +74,8 @@ class UserDirectoryStore(SQLBaseStore):
         """
         yield self._simple_insert_many(
             table="users_in_public_rooms",
-            values=[
-                {
-                    "user_id": user_id,
-                    "room_id": room_id,
-                }
-                for user_id in user_ids
-            ],
-            desc="add_users_to_public_room"
+            values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
+            desc="add_users_to_public_room",
         )
         for user_id in user_ids:
             self.get_user_in_public_room.invalidate((user_id,))
@@ -99,7 +101,9 @@ class UserDirectoryStore(SQLBaseStore):
             """
             args = (
                 (
-                    user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+                    user_id,
+                    get_localpart_from_id(user_id),
+                    get_domain_from_id(user_id),
                     profile.display_name,
                 )
                 for user_id, profile in iteritems(users_with_profile)
@@ -112,7 +116,7 @@ class UserDirectoryStore(SQLBaseStore):
             args = (
                 (
                     user_id,
-                    "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+                    "%s %s" % (user_id, p.display_name) if p.display_name else user_id,
                 )
                 for user_id, p in iteritems(users_with_profile)
             )
@@ -133,12 +137,10 @@ class UserDirectoryStore(SQLBaseStore):
                         "avatar_url": profile.avatar_url,
                     }
                     for user_id, profile in iteritems(users_with_profile)
-                ]
+                ],
             )
             for user_id in users_with_profile:
-                txn.call_after(
-                    self.get_user_in_directory.invalidate, (user_id,)
-                )
+                txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
         return self.runInteraction(
             "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
@@ -168,39 +170,69 @@ class UserDirectoryStore(SQLBaseStore):
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
-                if new_entry:
+                if self.database_engine.can_native_upsert:
                     sql = """
                         INSERT INTO user_directory_search(user_id, vector)
                         VALUES (?,
                             setweight(to_tsvector('english', ?), 'A')
                             || setweight(to_tsvector('english', ?), 'D')
                             || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
-                        )
+                        ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
                     """
                     txn.execute(
                         sql,
                         (
-                            user_id, get_localpart_from_id(user_id),
-                            get_domain_from_id(user_id), display_name,
-                        )
+                            user_id,
+                            get_localpart_from_id(user_id),
+                            get_domain_from_id(user_id),
+                            display_name,
+                        ),
                     )
                 else:
-                    sql = """
-                        UPDATE user_directory_search
-                        SET vector = setweight(to_tsvector('english', ?), 'A')
-                            || setweight(to_tsvector('english', ?), 'D')
-                            || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
-                        WHERE user_id = ?
-                    """
-                    txn.execute(
-                        sql,
-                        (
-                            get_localpart_from_id(user_id), get_domain_from_id(user_id),
-                            display_name, user_id,
+                    # TODO: Remove this code after we've bumped the minimum version
+                    # of postgres to always support upserts, so we can get rid of
+                    # `new_entry` usage
+                    if new_entry is True:
+                        sql = """
+                            INSERT INTO user_directory_search(user_id, vector)
+                            VALUES (?,
+                                setweight(to_tsvector('english', ?), 'A')
+                                || setweight(to_tsvector('english', ?), 'D')
+                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            )
+                        """
+                        txn.execute(
+                            sql,
+                            (
+                                user_id,
+                                get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id),
+                                display_name,
+                            ),
+                        )
+                    elif new_entry is False:
+                        sql = """
+                            UPDATE user_directory_search
+                            SET vector = setweight(to_tsvector('english', ?), 'A')
+                                || setweight(to_tsvector('english', ?), 'D')
+                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            WHERE user_id = ?
+                        """
+                        txn.execute(
+                            sql,
+                            (
+                                get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id),
+                                display_name,
+                                user_id,
+                            ),
+                        )
+                    else:
+                        raise RuntimeError(
+                            "upsert returned None when 'can_native_upsert' is False"
                         )
-                    )
             elif isinstance(self.database_engine, Sqlite3Engine):
-                value = "%s %s" % (user_id, display_name,) if display_name else user_id
+                value = "%s %s" % (user_id, display_name) if display_name else user_id
                 self._simple_upsert_txn(
                     txn,
                     table="user_directory_search",
@@ -231,29 +263,18 @@ class UserDirectoryStore(SQLBaseStore):
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
             self._simple_delete_txn(
-                txn,
-                table="user_directory",
-                keyvalues={"user_id": user_id},
+                txn, table="user_directory", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="user_directory_search",
-                keyvalues={"user_id": user_id},
+                txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="users_in_public_rooms",
-                keyvalues={"user_id": user_id},
-            )
-            txn.call_after(
-                self.get_user_in_directory.invalidate, (user_id,)
+                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
-            txn.call_after(
-                self.get_user_in_public_room.invalidate, (user_id,)
-            )
-        return self.runInteraction(
-            "remove_from_user_dir", _remove_from_user_dir_txn,
-        )
+            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+            txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
+
+        return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
 
     @defer.inlineCallbacks
     def remove_from_user_in_public_room(self, user_id):
@@ -338,6 +359,7 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
+
         def _add_users_who_share_room_txn(txn):
             self._simple_insert_many_txn(
                 txn,
@@ -354,13 +376,12 @@ class UserDirectoryStore(SQLBaseStore):
             )
             for user_id, other_user_id in user_id_tuples:
                 txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate,
-                    (user_id,),
+                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
                 txn.call_after(
-                    self.get_if_users_share_a_room.invalidate,
-                    (user_id, other_user_id),
+                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
                 )
+
         return self.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
@@ -374,6 +395,7 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
+
         def _update_users_who_share_room_txn(txn):
             sql = """
                 UPDATE users_who_share_rooms
@@ -381,21 +403,16 @@ class UserDirectoryStore(SQLBaseStore):
                 WHERE user_id = ? AND other_user_id = ?
             """
             txn.executemany(
-                sql,
-                (
-                    (room_id, share_private, uid, oid)
-                    for uid, oid in user_id_sets
-                )
+                sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
             )
             for user_id, other_user_id in user_id_sets:
                 txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate,
-                    (user_id,),
+                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
                 txn.call_after(
-                    self.get_if_users_share_a_room.invalidate,
-                    (user_id, other_user_id),
+                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
                 )
+
         return self.runInteraction(
             "update_users_who_share_room", _update_users_who_share_room_txn
         )
@@ -409,22 +426,18 @@ class UserDirectoryStore(SQLBaseStore):
             share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
+
         def _remove_user_who_share_room_txn(txn):
             self._simple_delete_txn(
                 txn,
                 table="users_who_share_rooms",
-                keyvalues={
-                    "user_id": user_id,
-                    "other_user_id": other_user_id,
-                },
+                keyvalues={"user_id": user_id, "other_user_id": other_user_id},
             )
             txn.call_after(
-                self.get_users_who_share_room_from_dir.invalidate,
-                (user_id,),
+                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
             )
             txn.call_after(
-                self.get_if_users_share_a_room.invalidate,
-                (user_id, other_user_id),
+                self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
             )
 
         return self.runInteraction(
@@ -445,10 +458,7 @@ class UserDirectoryStore(SQLBaseStore):
         """
         return self._simple_select_one_onecol(
             table="users_who_share_rooms",
-            keyvalues={
-                "user_id": user_id,
-                "other_user_id": other_user_id,
-            },
+            keyvalues={"user_id": user_id, "other_user_id": other_user_id},
             retcol="share_private",
             allow_none=True,
             desc="get_if_users_share_a_room",
@@ -466,17 +476,12 @@ class UserDirectoryStore(SQLBaseStore):
         """
         rows = yield self._simple_select_list(
             table="users_who_share_rooms",
-            keyvalues={
-                "user_id": user_id,
-            },
-            retcols=("other_user_id", "share_private",),
+            keyvalues={"user_id": user_id},
+            retcols=("other_user_id", "share_private"),
             desc="get_users_who_share_room_with_user",
         )
 
-        defer.returnValue({
-            row["other_user_id"]: row["share_private"]
-            for row in rows
-        })
+        defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
 
     def get_users_in_share_dir_with_room_id(self, user_id, room_id):
         """Get all user tuples that are in the users_who_share_rooms due to the
@@ -523,6 +528,7 @@ class UserDirectoryStore(SQLBaseStore):
     def delete_all_from_user_dir(self):
         """Delete the entire user directory
         """
+
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
@@ -532,6 +538,7 @@ class UserDirectoryStore(SQLBaseStore):
             txn.call_after(self.get_user_in_public_room.invalidate_all)
             txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
             txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
@@ -541,7 +548,7 @@ class UserDirectoryStore(SQLBaseStore):
         return self._simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
-            retcols=("room_id", "display_name", "avatar_url",),
+            retcols=("room_id", "display_name", "avatar_url"),
             allow_none=True,
             desc="get_user_in_directory",
         )
@@ -574,7 +581,9 @@ class UserDirectoryStore(SQLBaseStore):
 
     def get_current_state_deltas(self, prev_stream_id):
         prev_stream_id = int(prev_stream_id)
-        if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+        if not self._curr_state_delta_stream_cache.has_any_entity_changed(
+            prev_stream_id
+        ):
             return []
 
         def get_current_state_deltas_txn(txn):
@@ -608,7 +617,7 @@ class UserDirectoryStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id,))
+            txn.execute(sql, (prev_stream_id, max_stream_id))
             return self.cursor_to_dict(txn)
 
         return self.runInteraction(
@@ -698,8 +707,11 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (join_clause, where_clause)
-            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
+            """ % (
+                join_clause,
+                where_clause,
+            )
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
@@ -716,7 +728,10 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """ % (join_clause, where_clause)
+            """ % (
+                join_clause,
+                where_clause,
+            )
             args = join_args + (search_query, limit + 1)
         else:
             # This should be unreachable.
@@ -728,10 +743,7 @@ class UserDirectoryStore(SQLBaseStore):
 
         limited = len(results) > limit
 
-        defer.returnValue({
-            "limited": limited,
-            "results": results,
-        })
+        defer.returnValue({"limited": limited, "results": results})
 
 
 def _parse_query_sqlite(search_term):
@@ -746,7 +758,7 @@ def _parse_query_sqlite(search_term):
 
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
-    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
+    return " & ".join("(%s* OR %s)" % (result, result) for result in results)
 
 
 def _parse_query_postgres(search_term):
@@ -759,7 +771,7 @@ def _parse_query_postgres(search_term):
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
 
-    both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+    both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
     exact = " & ".join("%s" % (result,) for result in results)
     prefix = " & ".join("%s:*" % (result,) for result in results)