summary refs log tree commit diff
path: root/synapse/storage/user_directory.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/user_directory.py')
-rw-r--r--synapse/storage/user_directory.py34
1 files changed, 29 insertions, 5 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2e9175f50a..ca2be9daf2 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -21,6 +21,8 @@ from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
 
+import re
+
 
 class UserDirectoryStore(SQLBaseStore):
 
@@ -272,17 +274,17 @@ class UserDirectoryStore(SQLBaseStore):
                     ]
                 }
         """
-
+        search_query = _parse_query(self.database_engine, search_term)
         if isinstance(self.database_engine, PostgresEngine):
             sql = """
                 SELECT user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory USING (user_id)
-                WHERE vector @@ plainto_tsquery('english', ?)
-                ORDER BY ts_rank_cd(vector, plainto_tsquery('english', ?)) DESC
+                WHERE vector @@ to_tsquery('english', ?)
+                ORDER BY ts_rank_cd(vector, to_tsquery('english', ?)) DESC
                 LIMIT ?
             """
-            args = (search_term, search_term, limit + 1,)
+            args = (search_query, search_query, limit + 1,)
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = """
                 SELECT user_id, display_name, avatar_url
@@ -292,7 +294,7 @@ class UserDirectoryStore(SQLBaseStore):
                 ORDER BY rank(matchinfo(user_directory)) DESC
                 LIMIT ?
             """
-            args = (search_term, limit + 1)
+            args = (search_query, limit + 1)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -307,3 +309,25 @@ class UserDirectoryStore(SQLBaseStore):
             "limited": limited,
             "results": results,
         })
+
+
+def _parse_query(database_engine, search_term):
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to database.
+    We use this so that we can add prefix matching, which isn't something
+    that is supported by default.
+
+    We specifically add both a prefix and non prefix matching term so that
+    exact matches get ranked higher.
+    """
+
+    # Pull out the individual words, discarding any non-word characters.
+    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+    if isinstance(database_engine, PostgresEngine):
+        return " & ".join("%s:* & %s" % (result, result,) for result in results)
+    elif isinstance(database_engine, Sqlite3Engine):
+        return " & ".join("%s* & %s" % (result, result,) for result in results)
+    else:
+        # This should be unreachable.
+        raise Exception("Unrecognized database engine")