summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/search.py8
-rw-r--r--synapse/storage/search.py56
2 files changed, 61 insertions, 3 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index bc79564287..99ef56871c 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -152,11 +152,15 @@ class SearchHandler(BaseHandler):
 
         highlights = set()
 
+        count = None
+
         if order_by == "rank":
             search_result = yield self.store.search_msgs(
                 room_ids, search_term, keys
             )
 
+            count = search_result["count"]
+
             if search_result["highlights"]:
                 highlights.update(search_result["highlights"])
 
@@ -207,6 +211,8 @@ class SearchHandler(BaseHandler):
                 if search_result["highlights"]:
                     highlights.update(search_result["highlights"])
 
+                count = search_result["count"]
+
                 results = search_result["results"]
 
                 results_map = {r["event"].event_id: r for r in results}
@@ -359,7 +365,7 @@ class SearchHandler(BaseHandler):
 
         rooms_cat_res = {
             "results": results,
-            "count": len(results),
+            "count": count,
             "highlights": list(highlights),
         }
 
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index c39d54a7ca..efd87d99bb 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -162,6 +162,9 @@ class SearchStore(BackgroundUpdateStore):
             "(%s)" % (" OR ".join(local_clauses),)
         )
 
+        count_args = args
+        count_clauses = clauses
+
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
                 "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
@@ -170,6 +173,12 @@ class SearchStore(BackgroundUpdateStore):
                 " WHERE vector @@ to_tsquery('english', ?)"
             )
             args = [search_query, search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE vector @@ to_tsquery('english', ?)"
+            )
+            count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
@@ -177,6 +186,12 @@ class SearchStore(BackgroundUpdateStore):
                 " WHERE value MATCH ?"
             )
             args = [search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE value MATCH ? AND "
+            )
+            count_args = [search_term] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -184,6 +199,9 @@ class SearchStore(BackgroundUpdateStore):
         for clause in clauses:
             sql += " AND " + clause
 
+        for clause in count_clauses:
+            count_sql += " AND " + clause
+
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
@@ -205,6 +223,14 @@ class SearchStore(BackgroundUpdateStore):
         if isinstance(self.database_engine, PostgresEngine):
             highlights = yield self._find_highlights_in_postgres(search_query, events)
 
+        count_sql += " GROUP BY room_id"
+
+        count_results = yield self._execute(
+            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        )
+
+        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+
         defer.returnValue({
             "results": [
                 {
@@ -215,6 +241,7 @@ class SearchStore(BackgroundUpdateStore):
                 if r["event_id"] in event_map
             ],
             "highlights": highlights,
+            "count": count,
         })
 
     @defer.inlineCallbacks
@@ -254,6 +281,9 @@ class SearchStore(BackgroundUpdateStore):
             "(%s)" % (" OR ".join(local_clauses),)
         )
 
+        count_args = args
+        count_clauses = clauses
+
         if pagination_token:
             try:
                 origin_server_ts, stream = pagination_token.split(",")
@@ -276,7 +306,13 @@ class SearchStore(BackgroundUpdateStore):
                 " NATURAL JOIN events"
                 " WHERE vector @@ to_tsquery('english', ?) AND "
             )
-            args = [search_term, search_term] + args
+            args = [search_query, search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE vector @@ to_tsquery('english', ?) AND "
+            )
+            count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
@@ -296,12 +332,19 @@ class SearchStore(BackgroundUpdateStore):
                 " CROSS JOIN events USING (event_id)"
                 " WHERE "
             )
-            args = [search_term] + args
+            args = [search_query] + args
+
+            count_sql = (
+                "SELECT room_id, count(*) as count FROM event_search"
+                " WHERE value MATCH ? AND "
+            )
+            count_args = [search_term] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
         sql += " AND ".join(clauses)
+        count_sql += " AND ".join(count_clauses)
 
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
@@ -326,6 +369,14 @@ class SearchStore(BackgroundUpdateStore):
         if isinstance(self.database_engine, PostgresEngine):
             highlights = yield self._find_highlights_in_postgres(search_query, events)
 
+        count_sql += " GROUP BY room_id"
+
+        count_results = yield self._execute(
+            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        )
+
+        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+
         defer.returnValue({
             "results": [
                 {
@@ -339,6 +390,7 @@ class SearchStore(BackgroundUpdateStore):
                 if r["event_id"] in event_map
             ],
             "highlights": highlights,
+            "count": count,
         })
 
     def _find_highlights_in_postgres(self, search_query, events):