summary refs log tree commit diff
path: root/synapse/handlers/search.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-11-05 14:34:37 +0000
committerErik Johnston <erik@matrix.org>2015-11-05 15:04:08 +0000
commit7301e05122e07f6513916e8a35bf05581de6521d (patch)
tree6a6b196c4cc111abc593e6b472dfe276f0e0e522 /synapse/handlers/search.py
parentImplement order and group by (diff)
downloadsynapse-7301e05122e07f6513916e8a35bf05581de6521d.tar.xz
Implement basic pagination for search results
Diffstat (limited to 'synapse/handlers/search.py')
-rw-r--r--synapse/handlers/search.py78
1 files changed, 65 insertions, 13 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 28f5300dc9..696780f34e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -22,6 +22,8 @@ from synapse.api.filtering import Filter
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
 
+from unpaddedbase64 import decode_base64, encode_base64
+
 import logging
 
 
@@ -34,17 +36,32 @@ class SearchHandler(BaseHandler):
         super(SearchHandler, self).__init__(hs)
 
     @defer.inlineCallbacks
-    def search(self, user, content):
+    def search(self, user, content, batch=None):
         """Performs a full text search for a user.
 
         Args:
             user (UserID)
             content (dict): Search parameters
+            batch (str): The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
         """
 
+        batch_group = None
+        batch_group_key = None
+        batch_token = None
+        if batch:
+            try:
+                b = decode_base64(batch)
+                batch_group, batch_group_key, batch_token = b.split("\n")
+
+                assert batch_group is not None
+                assert batch_group_key is not None
+                assert batch_token is not None
+            except:
+                raise SynapseError(400, "Invalid batch")
+
         try:
             room_cat = content["search_categories"]["room_events"]
             search_term = room_cat["search_term"]
@@ -91,17 +108,25 @@ class SearchHandler(BaseHandler):
 
         room_ids = search_filter.filter_rooms(room_ids)
 
+        if batch_group == "room_id":
+            room_ids = room_ids & {batch_group_key}
+
         rank_map = {}
         allowed_events = []
         room_groups = {}
         sender_group = {}
+        global_next_batch = None
 
         if order_by == "rank":
-            rank_map, event_map, _ = yield self.store.search_msgs(
+            results = yield self.store.search_msgs(
                 room_ids, search_term, keys
             )
 
-            filtered_events = search_filter.filter(event_map.values())
+            results_map = {r["event"].event_id: r for r in results}
+
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = search_filter.filter([r["event"] for r in results])
 
             events = yield self._filter_events_for_client(
                 user.to_string(), filtered_events
@@ -126,18 +151,26 @@ class SearchHandler(BaseHandler):
         elif order_by == "recent":
             for room_id in room_ids:
                 room_events = []
-                pagination_token = None
+                if batch_group == "room_id" and batch_group_key == room_id:
+                    pagination_token = batch_token
+                else:
+                    pagination_token = None
                 i = 0
 
                 while len(room_events) < search_filter.limit() and i < 5:
                     i += 5
-                    r_map, event_map, pagination_token = yield self.store.search_room(
+                    results = yield self.store.search_room(
                         room_id, search_term, keys, search_filter.limit() * 2,
                         pagination_token=pagination_token,
                     )
-                    rank_map.update(r_map)
 
-                    filtered_events = search_filter.filter(event_map.values())
+                    results_map = {r["event"].event_id: r for r in results}
+
+                    rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+                    filtered_events = search_filter.filter([
+                        r["event"] for r in results
+                    ])
 
                     events = yield self._filter_events_for_client(
                         user.to_string(), filtered_events
@@ -146,13 +179,26 @@ class SearchHandler(BaseHandler):
                     room_events.extend(events)
                     room_events = room_events[:search_filter.limit()]
 
-                    if len(event_map) < search_filter.limit() * 2:
+                    if len(results) < search_filter.limit() * 2:
+                        pagination_token = None
                         break
+                    else:
+                        pagination_token = results[-1]["pagination_token"]
+
+                if room_events:
+                    res = results_map[room_events[-1].event_id]
+                    pagination_token = res["pagination_token"]
 
                 if room_events:
                     group = room_groups.setdefault(room_id, {})
                     if pagination_token:
-                        group["next_batch"] = pagination_token
+                        next_batch = encode_base64("%s\n%s\n%s" % (
+                            "room_id", room_id, pagination_token
+                        ))
+                        group["next_batch"] = next_batch
+
+                        if batch_token:
+                            global_next_batch = next_batch
 
                     group["results"] = [e.event_id for e in room_events]
                     group["order"] = max(
@@ -164,11 +210,14 @@ class SearchHandler(BaseHandler):
 
             # Normalize the group ranks
             if room_groups:
-                mx = max(g["order"] for g in room_groups.values())
-                mn = min(g["order"] for g in room_groups.values())
+                if len(room_groups) > 1:
+                    mx = max(g["order"] for g in room_groups.values())
+                    mn = min(g["order"] for g in room_groups.values())
 
-                for g in room_groups.values():
-                    g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+                    for g in room_groups.values():
+                        g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+                else:
+                    room_groups.values()[0]["order"] = 1
 
         else:
             # We should never get here due to the guard earlier.
@@ -239,6 +288,9 @@ class SearchHandler(BaseHandler):
         if sender_group and "sender" in group_keys:
             rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
 
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
+
         defer.returnValue({
             "search_categories": {
                 "room_events": rooms_cat_res