summary refs log tree commit diff
path: root/synapse/storage/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/search.py')
-rw-r--r--synapse/storage/search.py120
1 files changed, 103 insertions, 17 deletions
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index cdf003502f..3cea2011fa 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -16,18 +16,13 @@
 from twisted.internet import defer
 
 from _base import SQLBaseStore
+from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-from collections import namedtuple
+import logging
 
-"""The result of a search.
 
-Fields:
-    rank_map (dict): Mapping event_id -> rank
-    event_map (dict): Mapping event_id -> event
-    pagination_token (str): Pagination token
-"""
-SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
+logger = logging.getLogger(__name__)
 
 
 class SearchStore(SQLBaseStore):
@@ -42,7 +37,7 @@ class SearchStore(SQLBaseStore):
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            SearchResult
+            list of dicts
         """
         clauses = []
         args = []
@@ -100,12 +95,103 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        defer.returnValue(SearchResult(
+        defer.returnValue([
             {
-                r["event_id"]: r["rank"]
-                for r in results
-                if r["event_id"] in event_map
-            },
-            event_map,
-            None
-        ))
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])
+
+    @defer.inlineCallbacks
+    def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+        """Performs a full text search over events with given keys.
+
+        Args:
+            room_id (str): The room_id to search in
+            search_term (str): Search term to search for
+            keys (list): List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            pagination_token (str): A pagination token previously returned
+
+        Returns:
+            list of dicts
+        """
+        clauses = []
+        args = [search_term, room_id]
+
+        local_clauses = []
+        for key in keys:
+            local_clauses.append("key = ?")
+            args.append(key)
+
+        clauses.append(
+            "(%s)" % (" OR ".join(local_clauses),)
+        )
+
+        if pagination_token:
+            try:
+                topo, stream = pagination_token.split(",")
+                topo = int(topo)
+                stream = int(stream)
+            except:
+                raise SynapseError(400, "Invalid pagination token")
+
+            clauses.append(
+                "(topological_ordering < ?"
+                " OR (topological_ordering = ? AND stream_ordering < ?))"
+            )
+            args.extend([topo, topo, stream])
+
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "SELECT ts_rank_cd(vector, query) as rank,"
+                " topological_ordering, stream_ordering, room_id, event_id"
+                " FROM plainto_tsquery('english', ?) as query, event_search"
+                " NATURAL JOIN events"
+                " WHERE vector @@ query AND room_id = ?"
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
+                " topological_ordering, stream_ordering"
+                " FROM event_search"
+                " NATURAL JOIN events"
+                " WHERE value MATCH ? AND room_id = ?"
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        for clause in clauses:
+            sql += " AND " + clause
+
+        # We add an arbitrary limit here to ensure we don't try to pull the
+        # entire table from the database.
+        sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+
+        args.append(limit)
+
+        results = yield self._execute(
+            "search_rooms", self.cursor_to_dict, sql, *args
+        )
+
+        events = yield self._get_events([r["event_id"] for r in results])
+
+        event_map = {
+            ev.event_id: ev
+            for ev in events
+        }
+
+        defer.returnValue([
+            {
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+                "pagination_token": "%s,%s" % (
+                    r["topological_ordering"], r["stream_ordering"]
+                ),
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])