-rw-r--r--  synapse/handlers/search.py      | 180
-rw-r--r--  synapse/rest/client/v1/room.py  |   3
-rw-r--r--  synapse/storage/search.py       | 113
3 files changed, 260 insertions, 36 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 2718e9482e..696780f34e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -22,6 +22,8 @@ from synapse.api.filtering import Filter
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
 
+from unpaddedbase64 import decode_base64, encode_base64
+
 import logging
 
 
@@ -34,27 +36,47 @@ class SearchHandler(BaseHandler):
         super(SearchHandler, self).__init__(hs)
 
     @defer.inlineCallbacks
-    def search(self, user, content):
+    def search(self, user, content, batch=None):
         """Performs a full text search for a user.
 
         Args:
             user (UserID)
             content (dict): Search parameters
+            batch (str): The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
         """
 
+        batch_group = None
+        batch_group_key = None
+        batch_token = None
+        if batch:
+            try:
+                b = decode_base64(batch)
+                batch_group, batch_group_key, batch_token = b.split("\n")
+
+                assert batch_group is not None
+                assert batch_group_key is not None
+                assert batch_token is not None
+            except Exception:
+                raise SynapseError(400, "Invalid batch")
+
         try:
-            search_term = content["search_categories"]["room_events"]["search_term"]
-            keys = content["search_categories"]["room_events"].get("keys", [
+            room_cat = content["search_categories"]["room_events"]
+            search_term = room_cat["search_term"]
+            keys = room_cat.get("keys", [
                 "content.body", "content.name", "content.topic",
             ])
-            filter_dict = content["search_categories"]["room_events"].get("filter", {})
-            event_context = content["search_categories"]["room_events"].get(
+            filter_dict = room_cat.get("filter", {})
+            order_by = room_cat.get("order_by", "rank")
+            event_context = room_cat.get(
                 "event_context", None
             )
 
+            group_by = room_cat.get("groupings", {}).get("group_by", {})
+            group_keys = [g["key"] for g in group_by]
+
             if event_context is not None:
                 before_limit = int(event_context.get(
                     "before_limit", 5
@@ -65,6 +87,15 @@ class SearchHandler(BaseHandler):
         except KeyError:
             raise SynapseError(400, "Invalid search query")
 
+        if order_by not in ("rank", "recent"):
+            raise SynapseError(400, "Invalid order by: %r" % (order_by,))
+
+        if set(group_keys) - {"room_id", "sender"}:
+            raise SynapseError(
+                400,
+                "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+            )
+
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
@@ -77,18 +108,120 @@ class SearchHandler(BaseHandler):
 
         room_ids = search_filter.filter_rooms(room_ids)
 
-        rank_map, event_map, _ = yield self.store.search_msgs(
-            room_ids, search_term, keys
-        )
+        if batch_group == "room_id":
+            room_ids = room_ids & {batch_group_key}
 
-        filtered_events = search_filter.filter(event_map.values())
+        rank_map = {}
+        allowed_events = []
+        room_groups = {}
+        sender_group = {}
+        global_next_batch = None
 
-        allowed_events = yield self._filter_events_for_client(
-            user.to_string(), filtered_events
-        )
+        if order_by == "rank":
+            results = yield self.store.search_msgs(
+                room_ids, search_term, keys
+            )
+
+            results_map = {r["event"].event_id: r for r in results}
 
-        allowed_events.sort(key=lambda e: -rank_map[e.event_id])
-        allowed_events = allowed_events[:search_filter.limit()]
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = search_filter.filter([r["event"] for r in results])
+
+            events = yield self._filter_events_for_client(
+                user.to_string(), filtered_events
+            )
+
+            events.sort(key=lambda e: -rank_map[e.event_id])
+            allowed_events = events[:search_filter.limit()]
+
+            for e in allowed_events:
+                rm = room_groups.setdefault(e.room_id, {
+                    "results": [],
+                    "order": rank_map[e.event_id],
+                })
+                rm["results"].append(e.event_id)
+
+                s = sender_group.setdefault(e.sender, {
+                    "results": [],
+                    "order": rank_map[e.event_id],
+                })
+                s["results"].append(e.event_id)
+
+        elif order_by == "recent":
+            for room_id in room_ids:
+                room_events = []
+                if batch_group == "room_id" and batch_group_key == room_id:
+                    pagination_token = batch_token
+                else:
+                    pagination_token = None
+                i = 0
+
+                while len(room_events) < search_filter.limit() and i < 5:
+                    i += 1
+                    results = yield self.store.search_room(
+                        room_id, search_term, keys, search_filter.limit() * 2,
+                        pagination_token=pagination_token,
+                    )
+
+                    results_map = {r["event"].event_id: r for r in results}
+
+                    rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+                    filtered_events = search_filter.filter([
+                        r["event"] for r in results
+                    ])
+
+                    events = yield self._filter_events_for_client(
+                        user.to_string(), filtered_events
+                    )
+
+                    room_events.extend(events)
+                    room_events = room_events[:search_filter.limit()]
+
+                    if len(results) < search_filter.limit() * 2:
+                        pagination_token = None
+                        break
+                    else:
+                        pagination_token = results[-1]["pagination_token"]
+
+                if room_events:
+                    res = results_map[room_events[-1].event_id]
+                    pagination_token = res["pagination_token"]
+
+                if room_events:
+                    group = room_groups.setdefault(room_id, {})
+                    if pagination_token:
+                        next_batch = encode_base64("%s\n%s\n%s" % (
+                            "room_id", room_id, pagination_token
+                        ))
+                        group["next_batch"] = next_batch
+
+                        if batch_token:
+                            global_next_batch = next_batch
+
+                    group["results"] = [e.event_id for e in room_events]
+                    group["order"] = max(
+                        e.origin_server_ts/1000 for e in room_events
+                        if hasattr(e, "origin_server_ts")
+                    )
+
+                allowed_events.extend(room_events)
+
+            # Normalize the group ranks
+            if room_groups:
+                if len(room_groups) > 1:
+                    mx = max(g["order"] for g in room_groups.values())
+                    mn = min(g["order"] for g in room_groups.values())
+
+                    for g in room_groups.values():
+                        g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+                else:
+                    room_groups.values()[0]["order"] = 1
+
+        else:
+            # We should never get here due to the guard earlier.
+            raise NotImplementedError()
 
         if event_context is not None:
             now_token = yield self.hs.get_event_sources().get_current_token()
@@ -144,11 +277,22 @@ class SearchHandler(BaseHandler):
 
         logger.info("Found %d results", len(results))
 
+        rooms_cat_res = {
+            "results": results,
+            "count": len(results)
+        }
+
+        if room_groups and "room_id" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+
+        if sender_group and "sender" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
+
         defer.returnValue({
             "search_categories": {
-                "room_events": {
-                    "results": results,
-                    "count": len(results)
-                }
+                "room_events": rooms_cat_res
             }
         })
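
Note: the "next_batch" token minted above is simply the group key, the group
value and the store's pagination token joined with newlines and then
unpadded-base64 encoded. A minimal round-trip sketch (illustrative values
only; Python 2 era code, where str is bytes):

    from unpaddedbase64 import decode_base64, encode_base64

    token = encode_base64("%s\n%s\n%s" % ("room_id", "!abc:example.com", "12,34"))
    batch_group, batch_group_key, batch_token = decode_base64(token).split("\n")
    assert (batch_group, batch_group_key, batch_token) == \
        ("room_id", "!abc:example.com", "12,34")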
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index afb802baec..b1ea60eb5d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -604,7 +604,8 @@ class SearchRestServlet(ClientV1RestServlet):
 
         content = _parse_json(request)
 
-        results = yield self.handlers.search_handler.search(auth_user, content)
+        batch = request.args.get("next_batch", [None])[0]
+        results = yield self.handlers.search_handler.search(auth_user, content, batch)
 
         defer.returnValue((200, results))
 
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index cdf003502f..7342e7bae6 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -18,16 +18,10 @@ from twisted.internet import defer
 from _base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-from collections import namedtuple
+import logging
 
-"""The result of a search.
 
-Fields:
-    rank_map (dict): Mapping event_id -> rank
-    event_map (dict): Mapping event_id -> event
-    pagination_token (str): Pagination token
-"""
-SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
+logger = logging.getLogger(__name__)
 
 
 class SearchStore(SQLBaseStore):
@@ -42,7 +36,7 @@ class SearchStore(SQLBaseStore):
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            SearchResult
+            list of dicts, each with "event" and "rank" keys
         """
         clauses = []
         args = []
@@ -100,12 +94,97 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        defer.returnValue(SearchResult(
+        defer.returnValue([
             {
-                r["event_id"]: r["rank"]
-                for r in results
-                if r["event_id"] in event_map
-            },
-            event_map,
-            None
-        ))
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])
+
+    @defer.inlineCallbacks
+    def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+        """Performs a full text search over events with given keys.
+
+        Args:
+            room_id (str): The room_id to search in
+            search_term (str): Search term to search for
+            keys (list): List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            pagination_token (str): A pagination token previously returned
+
+        Returns:
+            list of dicts, each with "event", "rank" and "pagination_token" keys
+        """
+        clauses = []
+        args = [search_term, room_id]
+
+        local_clauses = []
+        for key in keys:
+            local_clauses.append("key = ?")
+            args.append(key)
+
+        clauses.append(
+            "(%s)" % (" OR ".join(local_clauses),)
+        )
+
+        if pagination_token:
+            topo, stream = pagination_token.split(",")
+            clauses.append(
+                "(topological_ordering < ?"
+                " OR (topological_ordering = ? AND stream_ordering < ?))"
+            )
+            args.extend([topo, topo, stream])
+
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "SELECT ts_rank_cd(vector, query) as rank,"
+                " topological_ordering, stream_ordering, room_id, event_id"
+                " FROM plainto_tsquery('english', ?) as query, event_search"
+                " NATURAL JOIN events"
+                " WHERE vector @@ query AND room_id = ?"
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
+                " topological_ordering, stream_ordering"
+                " FROM event_search"
+                " NATURAL JOIN events"
+                " WHERE value MATCH ? AND room_id = ?"
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        for clause in clauses:
+            sql += " AND " + clause
+
+        # Apply the caller-supplied limit to ensure we don't try to pull
+        # the entire table from the database.
+        sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+
+        args.append(limit)
+
+        results = yield self._execute(
+            "search_rooms", self.cursor_to_dict, sql, *args
+        )
+
+        events = yield self._get_events([r["event_id"] for r in results])
+
+        event_map = {
+            ev.event_id: ev
+            for ev in events
+        }
+
+        defer.returnValue([
+            {
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+                "pagination_token": "%s,%s" % (
+                    r["topological_ordering"], r["stream_ordering"]
+                ),
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])