summary refs log tree commit diff
path: root/synapse/storage/user_directory.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/user_directory.py')
-rw-r--r--synapse/storage/user_directory.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index bcf24fa4d0..6a4bf63f0d 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -147,6 +147,53 @@ class UserDirectoryStore(SQLBaseStore):
             updatevalues={"room_id": room_id},
             desc="update_user_in_user_dir",
         )
+        self.get_user_in_directory.invalidate((user_id,))
+
+    def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+        def _update_profile_in_user_dir_txn(txn):
+            self._simple_update_one_txn(
+                txn,
+                table="user_directory",
+                keyvalues={"user_id": user_id},
+                updatevalues={"display_name": display_name, "avatar_url": avatar_url},
+            )
+
+            if isinstance(self.database_engine, PostgresEngine):
+                # We weight the loclpart most highly, then display name and finally
+                # server name
+                sql = """
+                    UPDATE user_directory_search
+                    SET vector = setweight(to_tsvector('english', ?), 'A')
+                        || setweight(to_tsvector('english', ?), 'D')
+                        || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                    WHERE user_id = ?
+                """
+                args = (
+                    get_localpart_from_id(user_id), get_domain_from_id(user_id),
+                    display_name,
+                    user_id,
+                )
+            elif isinstance(self.database_engine, Sqlite3Engine):
+                sql = """
+                    UPDATE user_directory_search
+                    set value = ?
+                    WHERE user_id = ?
+                """
+                args = (
+                    "%s %s" % (user_id, display_name,) if display_name else user_id,
+                    user_id,
+                )
+            else:
+                # This should be unreachable.
+                raise Exception("Unrecognized database engine")
+
+            txn.execute(sql, args)
+
+            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+        return self.runInteraction(
+            "update_profile_in_user_dir", _update_profile_in_user_dir_txn
+        )
 
     @defer.inlineCallbacks
     def update_user_in_public_user_list(self, user_id, room_id):
@@ -156,6 +203,7 @@ class UserDirectoryStore(SQLBaseStore):
             updatevalues={"room_id": room_id},
             desc="update_user_in_public_user_list",
         )
+        self.get_user_in_public_room.invalidate((user_id,))
 
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
@@ -233,6 +281,7 @@ class UserDirectoryStore(SQLBaseStore):
             txn.execute("DELETE FROM user_directory_search")
             txn.execute("DELETE FROM users_in_pubic_room")
             txn.call_after(self.get_user_in_directory.invalidate_all)
+            txn.call_after(self.get_user_in_public_room.invalidate_all)
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )