diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 28f5300dc9..696780f34e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -22,6 +22,8 @@ from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
+from unpaddedbase64 import decode_base64, encode_base64
+
import logging
@@ -34,17 +36,32 @@ class SearchHandler(BaseHandler):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
- def search(self, user, content):
+ def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
+ batch (str): The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
"""
+ batch_group = None
+ batch_group_key = None
+ batch_token = None
+ if batch:
+ try:
+ b = decode_base64(batch)
+ batch_group, batch_group_key, batch_token = b.split("\n")
+
+ assert batch_group is not None
+ assert batch_group_key is not None
+ assert batch_token is not None
+ except:
+ raise SynapseError(400, "Invalid batch")
+
try:
room_cat = content["search_categories"]["room_events"]
search_term = room_cat["search_term"]
@@ -91,17 +108,25 @@ class SearchHandler(BaseHandler):
room_ids = search_filter.filter_rooms(room_ids)
+ if batch_group == "room_id":
+ room_ids = room_ids & {batch_group_key}
+
rank_map = {}
allowed_events = []
room_groups = {}
sender_group = {}
+ global_next_batch = None
if order_by == "rank":
- rank_map, event_map, _ = yield self.store.search_msgs(
+ results = yield self.store.search_msgs(
room_ids, search_term, keys
)
- filtered_events = search_filter.filter(event_map.values())
+ results_map = {r["event"].event_id: r for r in results}
+
+ rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+ filtered_events = search_filter.filter([r["event"] for r in results])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
@@ -126,18 +151,26 @@ class SearchHandler(BaseHandler):
elif order_by == "recent":
for room_id in room_ids:
room_events = []
- pagination_token = None
+ if batch_group == "room_id" and batch_group_key == room_id:
+ pagination_token = batch_token
+ else:
+ pagination_token = None
i = 0
while len(room_events) < search_filter.limit() and i < 5:
i += 5
- r_map, event_map, pagination_token = yield self.store.search_room(
+ results = yield self.store.search_room(
room_id, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token,
)
- rank_map.update(r_map)
- filtered_events = search_filter.filter(event_map.values())
+ results_map = {r["event"].event_id: r for r in results}
+
+ rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+ filtered_events = search_filter.filter([
+ r["event"] for r in results
+ ])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
@@ -146,13 +179,26 @@ class SearchHandler(BaseHandler):
room_events.extend(events)
room_events = room_events[:search_filter.limit()]
- if len(event_map) < search_filter.limit() * 2:
+ if len(results) < search_filter.limit() * 2:
+ pagination_token = None
break
+ else:
+ pagination_token = results[-1]["pagination_token"]
+
+ if room_events:
+ res = results_map[room_events[-1].event_id]
+ pagination_token = res["pagination_token"]
if room_events:
group = room_groups.setdefault(room_id, {})
if pagination_token:
- group["next_batch"] = pagination_token
+ next_batch = encode_base64("%s\n%s\n%s" % (
+ "room_id", room_id, pagination_token
+ ))
+ group["next_batch"] = next_batch
+
+ if batch_token:
+ global_next_batch = next_batch
group["results"] = [e.event_id for e in room_events]
group["order"] = max(
@@ -164,11 +210,14 @@ class SearchHandler(BaseHandler):
# Normalize the group ranks
if room_groups:
- mx = max(g["order"] for g in room_groups.values())
- mn = min(g["order"] for g in room_groups.values())
+ if len(room_groups) > 1:
+ mx = max(g["order"] for g in room_groups.values())
+ mn = min(g["order"] for g in room_groups.values())
- for g in room_groups.values():
- g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+ for g in room_groups.values():
+ g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+ else:
+ room_groups.values()[0]["order"] = 1
else:
# We should never get here due to the guard earlier.
@@ -239,6 +288,9 @@ class SearchHandler(BaseHandler):
if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+ if global_next_batch:
+ rooms_cat_res["next_batch"] = global_next_batch
+
defer.returnValue({
"search_categories": {
"room_events": rooms_cat_res
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2dcaee86cd..8e28f12d29 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -601,7 +601,8 @@ class SearchRestServlet(ClientV1RestServlet):
content = _parse_json(request)
- results = yield self.handlers.search_handler.search(auth_user, content)
+ batch = request.args.get("next_batch", [None])[0]
+ results = yield self.handlers.search_handler.search(auth_user, content, batch)
defer.returnValue((200, results))
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index e37e56c1f2..7342e7bae6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -18,24 +18,12 @@ from twisted.internet import defer
from _base import SQLBaseStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from collections import namedtuple
-
import logging
logger = logging.getLogger(__name__)
-"""The result of a search.
-
-Fields:
- rank_map (dict): Mapping event_id -> rank
- event_map (dict): Mapping event_id -> event
- pagination_token (str): Pagination token
-"""
-SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
-
-
class SearchStore(SQLBaseStore):
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
@@ -48,7 +36,7 @@ class SearchStore(SQLBaseStore):
"content.body", "content.name", "content.topic"
Returns:
- SearchResult
+ list of dicts
"""
clauses = []
args = []
@@ -106,15 +94,14 @@ class SearchStore(SQLBaseStore):
for ev in events
}
- defer.returnValue(SearchResult(
+ defer.returnValue([
{
- r["event_id"]: r["rank"]
- for r in results
- if r["event_id"] in event_map
- },
- event_map,
- None
- ))
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ])
@defer.inlineCallbacks
def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
@@ -128,7 +115,7 @@ class SearchStore(SQLBaseStore):
pagination_token (str): A pagination token previously returned
Returns:
- SearchResult
+ list of dicts
"""
clauses = []
args = [search_term, room_id]
@@ -190,18 +177,14 @@ class SearchStore(SQLBaseStore):
for ev in events
}
- pagination_token = None
- if results:
- topo = results[-1]["topological_ordering"]
- stream = results[-1]["stream_ordering"]
- pagination_token = "%s,%s" % (topo, stream)
-
- defer.returnValue(SearchResult(
+ defer.returnValue([
{
- r["event_id"]: r["rank"]
- for r in results
- if r["event_id"] in event_map
- },
- event_map,
- pagination_token
- ))
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s" % (
+ r["topological_ordering"], r["stream_ordering"]
+ ),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ])
|