summary refs log tree commit diff
path: root/synapse/storage/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/search.py')
-rw-r--r--synapse/storage/search.py35
1 files changed, 26 insertions, 9 deletions
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index a3c69c5ab3..9608b5d6a7 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -18,6 +18,17 @@ from twisted.internet import defer
 from _base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
+from collections import namedtuple
+
+"""The result of a search.
+
+Fields:
+    rank_map (dict): Mapping event_id -> rank
+    event_map (dict): Mapping event_id -> event
+    pagination_token (str): Pagination token
+"""
+SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
+
 
 class SearchStore(SQLBaseStore):
     @defer.inlineCallbacks
@@ -31,15 +42,18 @@ class SearchStore(SQLBaseStore):
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            2-tuple of (dict event_id -> rank, dict event_id -> event)
+            SearchResult
         """
         clauses = []
         args = []
 
-        clauses.append(
-            "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
-        )
-        args.extend(room_ids)
+        # Make sure we don't explode because the person is in too many rooms.
+        # We filter the results below regardless.
+        if len(room_ids) < 500:
+            clauses.append(
+                "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
+            )
+            args.extend(room_ids)
 
         local_clauses = []
         for key in keys:
@@ -52,13 +66,13 @@ class SearchStore(SQLBaseStore):
 
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
-                "SELECT ts_rank_cd(vector, query) AS rank, event_id"
+                "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
                 " FROM plainto_tsquery('english', ?) as query, event_search"
                 " WHERE vector @@ query"
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
-                "SELECT 0 as rank, event_id FROM event_search"
+                "SELECT 0 as rank, room_id, event_id FROM event_search"
                 " WHERE value MATCH ?"
             )
         else:
@@ -76,6 +90,8 @@ class SearchStore(SQLBaseStore):
             "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
         )
 
+        results = filter(lambda row: row["room_id"] in room_ids, results)
+
         events = yield self._get_events([r["event_id"] for r in results])
 
         event_map = {
@@ -83,11 +99,12 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        defer.returnValue((
+        defer.returnValue(SearchResult(
             {
                 r["event_id"]: r["rank"]
                 for r in results
                 if r["event_id"] in event_map
             },
-            event_map
+            event_map,
+            None
         ))