diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 380270b009..39f600f53c 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
+import re
logger = logging.getLogger(__name__)
@@ -139,7 +140,10 @@ class SearchStore(BackgroundUpdateStore):
list of dicts
"""
clauses = []
- args = []
+
+ search_query = search_query = _parse_query(self.database_engine, search_term)
+
+ args = [search_query]
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -161,7 +165,7 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
- " FROM plainto_tsquery('english', ?) as query, event_search"
+ " FROM to_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
@@ -182,7 +186,7 @@ class SearchStore(BackgroundUpdateStore):
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
- "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
+ "search_msgs", self.cursor_to_dict, sql, *args
)
results = filter(lambda row: row["room_id"] in room_ids, results)
@@ -194,21 +198,28 @@ class SearchStore(BackgroundUpdateStore):
for ev in events
}
- defer.returnValue([
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- }
- for r in results
- if r["event_id"] in event_map
- ])
+ highlights = None
+ if isinstance(self.database_engine, PostgresEngine):
+ highlights = yield self._find_highlights_in_postgres(search_query, events)
+
+ defer.returnValue({
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ })
@defer.inlineCallbacks
- def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+ def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
"""Performs a full text search over events with given keys.
Args:
- room_id (str): The room_id to search in
+ room_id (list): The room_ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
@@ -218,7 +229,18 @@ class SearchStore(BackgroundUpdateStore):
list of dicts
"""
clauses = []
- args = [search_term, room_id]
+
+ search_query = search_query = _parse_query(self.database_engine, search_term)
+
+ args = [search_query]
+
+ # Make sure we don't explode because the person is in too many rooms.
+ # We filter the results below regardless.
+ if len(room_ids) < 500:
+ clauses.append(
+ "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
+ )
+ args.extend(room_ids)
local_clauses = []
for key in keys:
@@ -231,25 +253,25 @@ class SearchStore(BackgroundUpdateStore):
if pagination_token:
try:
- topo, stream = pagination_token.split(",")
- topo = int(topo)
+ origin_server_ts, stream = pagination_token.split(",")
+ origin_server_ts = int(origin_server_ts)
stream = int(stream)
except:
raise SynapseError(400, "Invalid pagination token")
clauses.append(
- "(topological_ordering < ?"
- " OR (topological_ordering = ? AND stream_ordering < ?))"
+ "(origin_server_ts < ?"
+ " OR (origin_server_ts = ? AND stream_ordering < ?))"
)
- args.extend([topo, topo, stream])
+ args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
- " topological_ordering, stream_ordering, room_id, event_id"
- " FROM plainto_tsquery('english', ?) as query, event_search"
+ " origin_server_ts, stream_ordering, room_id, event_id"
+ " FROM to_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
- " WHERE vector @@ query AND room_id = ?"
+ " WHERE vector @@ query AND "
)
elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
@@ -262,24 +284,23 @@ class SearchStore(BackgroundUpdateStore):
# MATCH unless it uses the full text search index
sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id,"
- " topological_ordering, stream_ordering"
+ " origin_server_ts, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search"
" WHERE value MATCH ?"
" )"
" CROSS JOIN events USING (event_id)"
- " WHERE room_id = ?"
+ " WHERE "
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
- for clause in clauses:
- sql += " AND " + clause
+ sql += " AND ".join(clauses)
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
- sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+ sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
args.append(limit)
@@ -287,6 +308,8 @@ class SearchStore(BackgroundUpdateStore):
"search_rooms", self.cursor_to_dict, sql, *args
)
+ results = filter(lambda row: row["room_id"] in room_ids, results)
+
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
@@ -294,14 +317,110 @@ class SearchStore(BackgroundUpdateStore):
for ev in events
}
- defer.returnValue([
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s" % (
- r["topological_ordering"], r["stream_ordering"]
- ),
- }
- for r in results
- if r["event_id"] in event_map
- ])
+ highlights = None
+ if isinstance(self.database_engine, PostgresEngine):
+ highlights = yield self._find_highlights_in_postgres(search_query, events)
+
+ defer.returnValue({
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s" % (
+ r["origin_server_ts"], r["stream_ordering"]
+ ),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ })
+
+ def _find_highlights_in_postgres(self, search_query, events):
+ """Given a list of events and a search term, return a list of words
+ that match from the content of the event.
+
+ This is used to give a list of words that clients can match against to
+ highlight the matching parts.
+
+ Args:
+ search_query (str)
+ events (list): A list of events
+
+ Returns:
+ deferred : A set of strings.
+ """
+ def f(txn):
+ highlight_words = set()
+ for event in events:
+ # As a hack we simply join values of all possible keys. This is
+ # fine since we're only using them to find possible highlights.
+ values = []
+ for key in ("body", "name", "topic"):
+ v = event.content.get(key, None)
+ if v:
+ values.append(v)
+
+ if not values:
+ continue
+
+ value = " ".join(values)
+
+ # We need to find some values for StartSel and StopSel that
+ # aren't in the value so that we can pick results out.
+ start_sel = "<"
+ stop_sel = ">"
+
+ while start_sel in value:
+ start_sel += "<"
+ while stop_sel in value:
+ stop_sel += ">"
+
+ query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
+ _to_postgres_options({
+ "StartSel": start_sel,
+ "StopSel": stop_sel,
+ "MaxFragments": "50",
+ })
+ )
+ txn.execute(query, (value, search_query,))
+ headline, = txn.fetchall()[0]
+
+ # Now we need to pick the possible highlights out of the haedline
+ # result.
+ matcher_regex = "%s(.*?)%s" % (
+ re.escape(start_sel),
+ re.escape(stop_sel),
+ )
+
+ res = re.findall(matcher_regex, headline)
+ highlight_words.update([r.lower() for r in res])
+
+ return highlight_words
+
+ return self.runInteraction("_find_highlights", f)
+
+
+def _to_postgres_options(options_dict):
+ return "'%s'" % (
+ ",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
+ )
+
+
+def _parse_query(database_engine, search_term):
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to database.
+ We use this so that we can add prefix matching, which isn't something
+ that is supported by default.
+ """
+
+ # Pull out the individual words, discarding any non-word characters.
+ results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+ if isinstance(database_engine, PostgresEngine):
+ return " & ".join(result + ":*" for result in results)
+ elif isinstance(database_engine, Sqlite3Engine):
+ return " & ".join(result + "*" for result in results)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
|