diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 20a62d07ff..39f600f53c 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -140,7 +140,10 @@ class SearchStore(BackgroundUpdateStore):
list of dicts
"""
clauses = []
- args = []
+
+ search_query = search_query = _parse_query(self.database_engine, search_term)
+
+ args = [search_query]
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -162,7 +165,7 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
- " FROM plainto_tsquery('english', ?) as query, event_search"
+ " FROM to_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
@@ -183,7 +186,7 @@ class SearchStore(BackgroundUpdateStore):
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
- "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
+ "search_msgs", self.cursor_to_dict, sql, *args
)
results = filter(lambda row: row["room_id"] in room_ids, results)
@@ -197,7 +200,7 @@ class SearchStore(BackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_term, events)
+ highlights = yield self._find_highlights_in_postgres(search_query, events)
defer.returnValue({
"results": [
@@ -226,7 +229,10 @@ class SearchStore(BackgroundUpdateStore):
list of dicts
"""
clauses = []
- args = [search_term]
+
+ search_query = search_query = _parse_query(self.database_engine, search_term)
+
+ args = [search_query]
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -263,7 +269,7 @@ class SearchStore(BackgroundUpdateStore):
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
- " FROM plainto_tsquery('english', ?) as query, event_search"
+ " FROM to_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
" WHERE vector @@ query AND "
)
@@ -313,7 +319,7 @@ class SearchStore(BackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_term, events)
+ highlights = yield self._find_highlights_in_postgres(search_query, events)
defer.returnValue({
"results": [
@@ -330,7 +336,7 @@ class SearchStore(BackgroundUpdateStore):
"highlights": highlights,
})
- def _find_highlights_in_postgres(self, search_term, events):
+ def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -338,7 +344,7 @@ class SearchStore(BackgroundUpdateStore):
highlight the matching parts.
Args:
- search_term (str)
+ search_query (str)
events (list): A list of events
Returns:
@@ -370,14 +376,14 @@ class SearchStore(BackgroundUpdateStore):
while stop_sel in value:
stop_sel += ">"
- query = "SELECT ts_headline(?, plainto_tsquery('english', ?), %s)" % (
+ query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
_to_postgres_options({
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
})
)
- txn.execute(query, (value, search_term,))
+ txn.execute(query, (value, search_query,))
headline, = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
@@ -399,3 +405,22 @@ def _to_postgres_options(options_dict):
return "'%s'" % (
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
)
+
+
+def _parse_query(database_engine, search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ if isinstance(database_engine, PostgresEngine):
+ return " & ".join(result + ":*" for result in results)
+ elif isinstance(database_engine, Sqlite3Engine):
+ return " & ".join(result + "*" for result in results)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
|