summary refs log tree commit diff
path: root/synapse/storage/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/search.py')
-rw-r--r--synapse/storage/search.py263
1 files changed, 223 insertions, 40 deletions
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 380270b009..6cb5e73b6e 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 import logging
+import re
 
 
 logger = logging.getLogger(__name__)
@@ -84,6 +85,11 @@ class SearchStore(BackgroundUpdateStore):
                     # skip over it.
                     continue
 
+                if not isinstance(value, basestring):
+                    # If the event body, name or topic isn't a string
+                    # then skip over it
+                    continue
+
                 event_search_rows.append((event_id, room_id, key, value))
 
             if isinstance(self.database_engine, PostgresEngine):
@@ -139,6 +145,9 @@ class SearchStore(BackgroundUpdateStore):
             list of dicts
         """
         clauses = []
+
+        search_query = search_query = _parse_query(self.database_engine, search_term)
+
         args = []
 
         # Make sure we don't explode because the person is in too many rooms.
@@ -158,18 +167,36 @@ class SearchStore(BackgroundUpdateStore):
             "(%s)" % (" OR ".join(local_clauses),)
         )
 
+        count_args = args
+        count_clauses = clauses
+
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
-                "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
-                " FROM plainto_tsquery('english', ?) as query, event_search"
-                " WHERE vector @@ query"
+                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
+                " room_id, event_id"
+                " FROM event_search"
+                " WHERE vector @@ to_tsquery('english', ?)"
+            )
+            args = [search_query, search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE vector @@ to_tsquery('english', ?)"
             )
+            count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
                 " FROM event_search"
                 " WHERE value MATCH ?"
             )
+            args = [search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE value MATCH ?"
+            )
+            count_args = [search_term] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -177,12 +204,15 @@ class SearchStore(BackgroundUpdateStore):
         for clause in clauses:
             sql += " AND " + clause
 
+        for clause in count_clauses:
+            count_sql += " AND " + clause
+
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
         results = yield self._execute(
-            "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
+            "search_msgs", self.cursor_to_dict, sql, *args
         )
 
         results = filter(lambda row: row["room_id"] in room_ids, results)
@@ -194,21 +224,37 @@ class SearchStore(BackgroundUpdateStore):
             for ev in events
         }
 
-        defer.returnValue([
-            {
-                "event": event_map[r["event_id"]],
-                "rank": r["rank"],
-            }
-            for r in results
-            if r["event_id"] in event_map
-        ])
+        highlights = None
+        if isinstance(self.database_engine, PostgresEngine):
+            highlights = yield self._find_highlights_in_postgres(search_query, events)
+
+        count_sql += " GROUP BY room_id"
+
+        count_results = yield self._execute(
+            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        )
+
+        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+
+        defer.returnValue({
+            "results": [
+                {
+                    "event": event_map[r["event_id"]],
+                    "rank": r["rank"],
+                }
+                for r in results
+                if r["event_id"] in event_map
+            ],
+            "highlights": highlights,
+            "count": count,
+        })
 
     @defer.inlineCallbacks
-    def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+    def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
         """Performs a full text search over events with given keys.
 
         Args:
-            room_id (str): The room_id to search in
+            room_id (list): The room_ids to search in
             search_term (str): Search term to search for
             keys (list): List of keys to search in, currently supports
                 "content.body", "content.name", "content.topic"
@@ -218,7 +264,18 @@ class SearchStore(BackgroundUpdateStore):
             list of dicts
         """
         clauses = []
-        args = [search_term, room_id]
+
+        search_query = search_query = _parse_query(self.database_engine, search_term)
+
+        args = []
+
+        # Make sure we don't explode because the person is in too many rooms.
+        # We filter the results below regardless.
+        if len(room_ids) < 500:
+            clauses.append(
+                "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
+            )
+            args.extend(room_ids)
 
         local_clauses = []
         for key in keys:
@@ -229,28 +286,40 @@ class SearchStore(BackgroundUpdateStore):
             "(%s)" % (" OR ".join(local_clauses),)
         )
 
+        # take copies of the current args and clauses lists, before adding
+        # pagination clauses to main query.
+        count_args = list(args)
+        count_clauses = list(clauses)
+
         if pagination_token:
             try:
-                topo, stream = pagination_token.split(",")
-                topo = int(topo)
+                origin_server_ts, stream = pagination_token.split(",")
+                origin_server_ts = int(origin_server_ts)
                 stream = int(stream)
             except:
                 raise SynapseError(400, "Invalid pagination token")
 
             clauses.append(
-                "(topological_ordering < ?"
-                " OR (topological_ordering = ? AND stream_ordering < ?))"
+                "(origin_server_ts < ?"
+                " OR (origin_server_ts = ? AND stream_ordering < ?))"
             )
-            args.extend([topo, topo, stream])
+            args.extend([origin_server_ts, origin_server_ts, stream])
 
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
-                "SELECT ts_rank_cd(vector, query) as rank,"
-                " topological_ordering, stream_ordering, room_id, event_id"
-                " FROM plainto_tsquery('english', ?) as query, event_search"
+                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
+                " origin_server_ts, stream_ordering, room_id, event_id"
+                " FROM event_search"
                 " NATURAL JOIN events"
-                " WHERE vector @@ query AND room_id = ?"
+                " WHERE vector @@ to_tsquery('english', ?) AND "
+            )
+            args = [search_query, search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE vector @@ to_tsquery('english', ?) AND "
             )
+            count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
@@ -262,24 +331,31 @@ class SearchStore(BackgroundUpdateStore):
             # MATCH unless it uses the full text search index
             sql = (
                 "SELECT rank(matchinfo) as rank, room_id, event_id,"
-                " topological_ordering, stream_ordering"
+                " origin_server_ts, stream_ordering"
                 " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
                 " FROM event_search"
                 " WHERE value MATCH ?"
                 " )"
                 " CROSS JOIN events USING (event_id)"
-                " WHERE room_id = ?"
+                " WHERE "
             )
+            args = [search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE value MATCH ? AND "
+            )
+            count_args = [search_term] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        for clause in clauses:
-            sql += " AND " + clause
+        sql += " AND ".join(clauses)
+        count_sql += " AND ".join(count_clauses)
 
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
-        sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+        sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
 
         args.append(limit)
 
@@ -287,6 +363,8 @@ class SearchStore(BackgroundUpdateStore):
             "search_rooms", self.cursor_to_dict, sql, *args
         )
 
+        results = filter(lambda row: row["room_id"] in room_ids, results)
+
         events = yield self._get_events([r["event_id"] for r in results])
 
         event_map = {
@@ -294,14 +372,119 @@ class SearchStore(BackgroundUpdateStore):
             for ev in events
         }
 
-        defer.returnValue([
-            {
-                "event": event_map[r["event_id"]],
-                "rank": r["rank"],
-                "pagination_token": "%s,%s" % (
-                    r["topological_ordering"], r["stream_ordering"]
-                ),
-            }
-            for r in results
-            if r["event_id"] in event_map
-        ])
+        highlights = None
+        if isinstance(self.database_engine, PostgresEngine):
+            highlights = yield self._find_highlights_in_postgres(search_query, events)
+
+        count_sql += " GROUP BY room_id"
+
+        count_results = yield self._execute(
+            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        )
+
+        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+
+        defer.returnValue({
+            "results": [
+                {
+                    "event": event_map[r["event_id"]],
+                    "rank": r["rank"],
+                    "pagination_token": "%s,%s" % (
+                        r["origin_server_ts"], r["stream_ordering"]
+                    ),
+                }
+                for r in results
+                if r["event_id"] in event_map
+            ],
+            "highlights": highlights,
+            "count": count,
+        })
+
+    def _find_highlights_in_postgres(self, search_query, events):
+        """Given a list of events and a search term, return a list of words
+        that match from the content of the event.
+
+        This is used to give a list of words that clients can match against to
+        highlight the matching parts.
+
+        Args:
+            search_query (str)
+            events (list): A list of events
+
+        Returns:
+            deferred : A set of strings.
+        """
+        def f(txn):
+            highlight_words = set()
+            for event in events:
+                # As a hack we simply join values of all possible keys. This is
+                # fine since we're only using them to find possible highlights.
+                values = []
+                for key in ("body", "name", "topic"):
+                    v = event.content.get(key, None)
+                    if v:
+                        values.append(v)
+
+                if not values:
+                    continue
+
+                value = " ".join(values)
+
+                # We need to find some values for StartSel and StopSel that
+                # aren't in the value so that we can pick results out.
+                start_sel = "<"
+                stop_sel = ">"
+
+                while start_sel in value:
+                    start_sel += "<"
+                while stop_sel in value:
+                    stop_sel += ">"
+
+                query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
+                    _to_postgres_options({
+                        "StartSel": start_sel,
+                        "StopSel": stop_sel,
+                        "MaxFragments": "50",
+                    })
+                )
+                txn.execute(query, (value, search_query,))
+                headline, = txn.fetchall()[0]
+
+                # Now we need to pick the possible highlights out of the haedline
+                # result.
+                matcher_regex = "%s(.*?)%s" % (
+                    re.escape(start_sel),
+                    re.escape(stop_sel),
+                )
+
+                res = re.findall(matcher_regex, headline)
+                highlight_words.update([r.lower() for r in res])
+
+            return highlight_words
+
+        return self.runInteraction("_find_highlights", f)
+
+
+def _to_postgres_options(options_dict):
+    return "'%s'" % (
+        ",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
+    )
+
+
+def _parse_query(database_engine, search_term):
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to database.
+    We use this so that we can add prefix matching, which isn't something
+    that is supported by default.
+    """
+
+    # Pull out the individual words, discarding any non-word characters.
+    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+    if isinstance(database_engine, PostgresEngine):
+        return " & ".join(result + ":*" for result in results)
+    elif isinstance(database_engine, Sqlite3Engine):
+        return " & ".join(result + "*" for result in results)
+    else:
+        # This should be unreachable.
+        raise Exception("Unrecognized database engine")