summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/user_directory.py4
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py8
-rw-r--r--synapse/storage/user_directory.py61
3 files changed, 47 insertions, 26 deletions
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index aa8af95177..8928786fd6 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -71,7 +71,7 @@ class UserDirectoyHandler(object):
         # we start populating the user directory
         self.clock.call_later(0, self.notify_new_event)
 
-    def search_users(self, search_term, limit):
+    def search_users(self, user_id, search_term, limit):
         """Searches for users in directory
 
         Returns:
@@ -88,7 +88,7 @@ class UserDirectoyHandler(object):
                     ]
                 }
         """
-        return self.store.search_user_dir(search_term, limit)
+        return self.store.search_user_dir(user_id, search_term, limit)
 
     @defer.inlineCallbacks
     def notify_new_event(self):
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 17d3dffc8f..6e012da4aa 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -55,7 +55,9 @@ class UserDirectorySearchRestServlet(RestServlet):
                     ]
                 }
         """
-        yield self.auth.get_user_by_req(request, allow_guest=False)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+
         body = parse_json_object_from_request(request)
 
         limit = body.get("limit", 10)
@@ -66,7 +68,9 @@ class UserDirectorySearchRestServlet(RestServlet):
         except:
             raise SynapseError(400, "`search_term` is required field")
 
-        results = yield self.user_directory_handler.search_users(search_term, limit)
+        results = yield self.user_directory_handler.search_users(
+            user_id, search_term, limit,
+        )
 
         defer.returnValue((200, results))
 
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2a17cbc9e9..52b184fe78 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -611,7 +611,7 @@ class UserDirectoryStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def search_user_dir(self, search_term, limit):
+    def search_user_dir(self, user_id, search_term, limit):
         """Searches for users in directory
 
         Returns:
@@ -637,46 +637,63 @@ class UserDirectoryStore(SQLBaseStore):
             # The array of numbers are the weights for the various part of the
             # search: (domain, _, display name, localpart)
             sql = """
-                SELECT user_id, display_name, avatar_url
+                SELECT d.user_id, display_name, avatar_url
                 FROM user_directory_search
-                INNER JOIN user_directory USING (user_id)
-                INNER JOIN users_in_pubic_room USING (user_id)
-                WHERE vector @@ to_tsquery('english', ?)
+                INNER JOIN user_directory AS d USING (user_id)
+                LEFT JOIN users_in_pubic_room AS p USING (user_id)
+                LEFT JOIN (
+                    SELECT other_user_id AS user_id FROM users_who_share_rooms
+                    WHERE user_id = ? AND share_private
+                ) AS s USING (user_id)
+                WHERE
+                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    AND vector @@ to_tsquery('english', ?)
                 ORDER BY
-                    2 * ts_rank_cd(
-                        '{0.1, 0.1, 0.9, 1.0}',
-                        vector,
-                        to_tsquery('english', ?),
-                        8
-                    )
-                    + ts_rank_cd(
-                        '{0.1, 0.1, 0.9, 1.0}',
-                        vector,
-                        to_tsquery('english', ?),
-                        8
+                    (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+                    * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
+                    * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
+                    * (
+                        3 * ts_rank_cd(
+                            '{0.1, 0.1, 0.9, 1.0}',
+                            vector,
+                            to_tsquery('english', ?),
+                            8
+                        )
+                        + ts_rank_cd(
+                            '{0.1, 0.1, 0.9, 1.0}',
+                            vector,
+                            to_tsquery('english', ?),
+                            8
+                        )
                     )
                     DESC,
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
             """
-            args = (full_query, exact_query, prefix_query, limit + 1,)
+            args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
             sql = """
-                SELECT user_id, display_name, avatar_url
+                SELECT d.user_id, display_name, avatar_url
                 FROM user_directory_search
-                INNER JOIN user_directory USING (user_id)
-                INNER JOIN users_in_pubic_room USING (user_id)
-                WHERE value MATCH ?
+                INNER JOIN user_directory AS d USING (user_id)
+                LEFT JOIN users_in_pubic_room AS p USING (user_id)
+                LEFT JOIN (
+                    SELECT other_user_id AS user_id FROM users_who_share_rooms
+                    WHERE user_id = ? AND share_private
+                ) AS s USING (user_id)
+                WHERE
+                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    AND value MATCH ?
                 ORDER BY
                     rank(matchinfo(user_directory_search)) DESC,
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
             """
-            args = (search_query, limit + 1)
+            args = (user_id, search_query, limit + 1)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")