summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/search.py78
-rw-r--r--synapse/rest/client/v1/room.py3
-rw-r--r--synapse/storage/search.py55
3 files changed, 86 insertions, 50 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 28f5300dc9..696780f34e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -22,6 +22,8 @@ from synapse.api.filtering import Filter
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
 
+from unpaddedbase64 import decode_base64, encode_base64
+
 import logging
 
 
@@ -34,17 +36,32 @@ class SearchHandler(BaseHandler):
         super(SearchHandler, self).__init__(hs)
 
     @defer.inlineCallbacks
-    def search(self, user, content):
+    def search(self, user, content, batch=None):
         """Performs a full text search for a user.
 
         Args:
             user (UserID)
             content (dict): Search parameters
+            batch (str): The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
         """
 
+        batch_group = None
+        batch_group_key = None
+        batch_token = None
+        if batch:
+            try:
+                b = decode_base64(batch)
+                batch_group, batch_group_key, batch_token = b.split("\n")
+
+                assert batch_group is not None
+                assert batch_group_key is not None
+                assert batch_token is not None
+            except:
+                raise SynapseError(400, "Invalid batch")
+
         try:
             room_cat = content["search_categories"]["room_events"]
             search_term = room_cat["search_term"]
@@ -91,17 +108,25 @@ class SearchHandler(BaseHandler):
 
         room_ids = search_filter.filter_rooms(room_ids)
 
+        if batch_group == "room_id":
+            room_ids = room_ids & {batch_group_key}
+
         rank_map = {}
         allowed_events = []
         room_groups = {}
         sender_group = {}
+        global_next_batch = None
 
         if order_by == "rank":
-            rank_map, event_map, _ = yield self.store.search_msgs(
+            results = yield self.store.search_msgs(
                 room_ids, search_term, keys
             )
 
-            filtered_events = search_filter.filter(event_map.values())
+            results_map = {r["event"].event_id: r for r in results}
+
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = search_filter.filter([r["event"] for r in results])
 
             events = yield self._filter_events_for_client(
                 user.to_string(), filtered_events
@@ -126,18 +151,26 @@ class SearchHandler(BaseHandler):
         elif order_by == "recent":
             for room_id in room_ids:
                 room_events = []
-                pagination_token = None
+                if batch_group == "room_id" and batch_group_key == room_id:
+                    pagination_token = batch_token
+                else:
+                    pagination_token = None
                 i = 0
 
                 while len(room_events) < search_filter.limit() and i < 5:
                     i += 5
-                    r_map, event_map, pagination_token = yield self.store.search_room(
+                    results = yield self.store.search_room(
                         room_id, search_term, keys, search_filter.limit() * 2,
                         pagination_token=pagination_token,
                     )
-                    rank_map.update(r_map)
 
-                    filtered_events = search_filter.filter(event_map.values())
+                    results_map = {r["event"].event_id: r for r in results}
+
+                    rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+                    filtered_events = search_filter.filter([
+                        r["event"] for r in results
+                    ])
 
                     events = yield self._filter_events_for_client(
                         user.to_string(), filtered_events
@@ -146,13 +179,26 @@ class SearchHandler(BaseHandler):
                     room_events.extend(events)
                     room_events = room_events[:search_filter.limit()]
 
-                    if len(event_map) < search_filter.limit() * 2:
+                    if len(results) < search_filter.limit() * 2:
+                        pagination_token = None
                         break
+                    else:
+                        pagination_token = results[-1]["pagination_token"]
+
+                if room_events:
+                    res = results_map[room_events[-1].event_id]
+                    pagination_token = res["pagination_token"]
 
                 if room_events:
                     group = room_groups.setdefault(room_id, {})
                     if pagination_token:
-                        group["next_batch"] = pagination_token
+                        next_batch = encode_base64("%s\n%s\n%s" % (
+                            "room_id", room_id, pagination_token
+                        ))
+                        group["next_batch"] = next_batch
+
+                        if batch_token:
+                            global_next_batch = next_batch
 
                     group["results"] = [e.event_id for e in room_events]
                     group["order"] = max(
@@ -164,11 +210,14 @@ class SearchHandler(BaseHandler):
 
             # Normalize the group ranks
             if room_groups:
-                mx = max(g["order"] for g in room_groups.values())
-                mn = min(g["order"] for g in room_groups.values())
+                if len(room_groups) > 1:
+                    mx = max(g["order"] for g in room_groups.values())
+                    mn = min(g["order"] for g in room_groups.values())
 
-                for g in room_groups.values():
-                    g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+                    for g in room_groups.values():
+                        g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+                else:
+                    room_groups.values()[0]["order"] = 1
 
         else:
             # We should never get here due to the guard earlier.
@@ -239,6 +288,9 @@ class SearchHandler(BaseHandler):
         if sender_group and "sender" in group_keys:
             rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
 
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
+
         defer.returnValue({
             "search_categories": {
                 "room_events": rooms_cat_res
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2dcaee86cd..8e28f12d29 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -601,7 +601,8 @@ class SearchRestServlet(ClientV1RestServlet):
 
         content = _parse_json(request)
 
-        results = yield self.handlers.search_handler.search(auth_user, content)
+        batch = request.args.get("next_batch", [None])[0]
+        results = yield self.handlers.search_handler.search(auth_user, content, batch)
 
         defer.returnValue((200, results))
 
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index e37e56c1f2..7342e7bae6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -18,24 +18,12 @@ from twisted.internet import defer
 from _base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-from collections import namedtuple
-
 import logging
 
 
 logger = logging.getLogger(__name__)
 
 
-"""The result of a search.
-
-Fields:
-    rank_map (dict): Mapping event_id -> rank
-    event_map (dict): Mapping event_id -> event
-    pagination_token (str): Pagination token
-"""
-SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
-
-
 class SearchStore(SQLBaseStore):
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
@@ -48,7 +36,7 @@ class SearchStore(SQLBaseStore):
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            SearchResult
+            list of dicts
         """
         clauses = []
         args = []
@@ -106,15 +94,14 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        defer.returnValue(SearchResult(
+        defer.returnValue([
             {
-                r["event_id"]: r["rank"]
-                for r in results
-                if r["event_id"] in event_map
-            },
-            event_map,
-            None
-        ))
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])
 
     @defer.inlineCallbacks
     def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
@@ -128,7 +115,7 @@ class SearchStore(SQLBaseStore):
             pagination_token (str): A pagination token previously returned
 
         Returns:
-            SearchResult
+            list of dicts
         """
         clauses = []
         args = [search_term, room_id]
@@ -190,18 +177,14 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        pagination_token = None
-        if results:
-            topo = results[-1]["topological_ordering"]
-            stream = results[-1]["stream_ordering"]
-            pagination_token = "%s,%s" % (topo, stream)
-
-        defer.returnValue(SearchResult(
+        defer.returnValue([
             {
-                r["event_id"]: r["rank"]
-                for r in results
-                if r["event_id"] in event_map
-            },
-            event_map,
-            pagination_token
-        ))
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+                "pagination_token": "%s,%s" % (
+                    r["topological_ordering"], r["stream_ordering"]
+                ),
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])