diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index c39d54a7ca..efd87d99bb 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -162,6 +162,9 @@ class SearchStore(BackgroundUpdateStore):
"(%s)" % (" OR ".join(local_clauses),)
)
+ count_args = args
+ count_clauses = clauses
+
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
@@ -170,6 +173,12 @@ class SearchStore(BackgroundUpdateStore):
" WHERE vector @@ to_tsquery('english', ?)"
)
args = [search_query, search_query] + args
+
+ count_sql = (
+ "SELECT room_id, count(*) as count FROM event_search"
+ " WHERE vector @@ to_tsquery('english', ?)"
+ )
+ count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
@@ -177,6 +186,12 @@ class SearchStore(BackgroundUpdateStore):
" WHERE value MATCH ?"
)
args = [search_query] + args
+
+ count_sql = (
+ "SELECT room_id, count(*) as count FROM event_search"
+ " WHERE value MATCH ? AND "
+ )
+ count_args = [search_term] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -184,6 +199,9 @@ class SearchStore(BackgroundUpdateStore):
for clause in clauses:
sql += " AND " + clause
+ for clause in count_clauses:
+ count_sql += " AND " + clause
+
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
@@ -205,6 +223,14 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
+ count_sql += " GROUP BY room_id"
+
+ count_results = yield self._execute(
+ "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ )
+
+ count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+
defer.returnValue({
"results": [
{
@@ -215,6 +241,7 @@ class SearchStore(BackgroundUpdateStore):
if r["event_id"] in event_map
],
"highlights": highlights,
+ "count": count,
})
@defer.inlineCallbacks
@@ -254,6 +281,9 @@ class SearchStore(BackgroundUpdateStore):
"(%s)" % (" OR ".join(local_clauses),)
)
+ count_args = args
+ count_clauses = clauses
+
if pagination_token:
try:
origin_server_ts, stream = pagination_token.split(",")
@@ -276,7 +306,13 @@ class SearchStore(BackgroundUpdateStore):
" NATURAL JOIN events"
" WHERE vector @@ to_tsquery('english', ?) AND "
)
- args = [search_term, search_term] + args
+ args = [search_query, search_query] + args
+
+ count_sql = (
+ "SELECT room_id, count(*) as count FROM event_search"
+ " WHERE vector @@ to_tsquery('english', ?) AND "
+ )
+ count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
@@ -296,12 +332,19 @@ class SearchStore(BackgroundUpdateStore):
" CROSS JOIN events USING (event_id)"
" WHERE "
)
- args = [search_term] + args
+ args = [search_query] + args
+
+ count_sql = (
+ "SELECT room_id, count(*) as count FROM event_search"
+ " WHERE value MATCH ? AND "
+ )
+ count_args = [search_term] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
sql += " AND ".join(clauses)
+ count_sql += " AND ".join(count_clauses)
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
@@ -326,6 +369,14 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
+ count_sql += " GROUP BY room_id"
+
+ count_results = yield self._execute(
+ "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ )
+
+ count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+
defer.returnValue({
"results": [
{
@@ -339,6 +390,7 @@ class SearchStore(BackgroundUpdateStore):
if r["event_id"] in event_map
],
"highlights": highlights,
+ "count": count,
})
def _find_highlights_in_postgres(self, search_query, events):
|