diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/user_directory.py | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 2e9175f50a..ca2be9daf2 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -21,6 +21,8 @@ from synapse.api.constants import EventTypes, JoinRules from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id +import re + class UserDirectoryStore(SQLBaseStore): @@ -272,17 +274,17 @@ class UserDirectoryStore(SQLBaseStore): ] } """ - + search_query = _parse_query(self.database_engine, search_term) if isinstance(self.database_engine, PostgresEngine): sql = """ SELECT user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory USING (user_id) - WHERE vector @@ plainto_tsquery('english', ?) - ORDER BY ts_rank_cd(vector, plainto_tsquery('english', ?)) DESC + WHERE vector @@ to_tsquery('english', ?) + ORDER BY ts_rank_cd(vector, to_tsquery('english', ?)) DESC LIMIT ? """ - args = (search_term, search_term, limit + 1,) + args = (search_query, search_query, limit + 1,) elif isinstance(self.database_engine, Sqlite3Engine): sql = """ SELECT user_id, display_name, avatar_url @@ -292,7 +294,7 @@ class UserDirectoryStore(SQLBaseStore): ORDER BY rank(matchinfo(user_directory)) DESC LIMIT ? """ - args = (search_term, limit + 1) + args = (search_query, limit + 1) else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -307,3 +309,25 @@ class UserDirectoryStore(SQLBaseStore): "limited": limited, "results": results, }) + + +def _parse_query(database_engine, search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + if isinstance(database_engine, PostgresEngine): + return " & ".join("%s:* & %s" % (result, result,) for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join("%s* & %s" % (result, result,) for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") |