diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2e9175f50a..ca2be9daf2 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -21,6 +21,8 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
+import re
+
class UserDirectoryStore(SQLBaseStore):
@@ -272,17 +274,17 @@ class UserDirectoryStore(SQLBaseStore):
]
}
"""
-
+ search_query = _parse_query(self.database_engine, search_term)
if isinstance(self.database_engine, PostgresEngine):
sql = """
SELECT user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory USING (user_id)
- WHERE vector @@ plainto_tsquery('english', ?)
- ORDER BY ts_rank_cd(vector, plainto_tsquery('english', ?)) DESC
+ WHERE vector @@ to_tsquery('english', ?)
+ ORDER BY ts_rank_cd(vector, to_tsquery('english', ?)) DESC
LIMIT ?
"""
- args = (search_term, search_term, limit + 1,)
+ args = (search_query, search_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
SELECT user_id, display_name, avatar_url
@@ -292,7 +294,7 @@ class UserDirectoryStore(SQLBaseStore):
ORDER BY rank(matchinfo(user_directory)) DESC
LIMIT ?
"""
- args = (search_term, limit + 1)
+ args = (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -307,3 +309,25 @@ class UserDirectoryStore(SQLBaseStore):
"limited": limited,
"results": results,
})
+
+
+def _parse_query(database_engine, search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+
+ We specifically add both a prefix and non prefix matching term so that
+ exact matches get ranked higher.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ if isinstance(database_engine, PostgresEngine):
+ return " & ".join("%s:* & %s" % (result, result,) for result in results)
+ elif isinstance(database_engine, Sqlite3Engine):
+ return " & ".join("%s* & %s" % (result, result,) for result in results)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
|