summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/search.py126
-rw-r--r--synapse/storage/search.py96
2 files changed, 205 insertions, 17 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 2718e9482e..28f5300dc9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -46,15 +46,20 @@ class SearchHandler(BaseHandler):
         """
 
         try:
-            search_term = content["search_categories"]["room_events"]["search_term"]
-            keys = content["search_categories"]["room_events"].get("keys", [
+            room_cat = content["search_categories"]["room_events"]
+            search_term = room_cat["search_term"]
+            keys = room_cat.get("keys", [
                 "content.body", "content.name", "content.topic",
             ])
-            filter_dict = content["search_categories"]["room_events"].get("filter", {})
-            event_context = content["search_categories"]["room_events"].get(
+            filter_dict = room_cat.get("filter", {})
+            order_by = room_cat.get("order_by", "rank")
+            event_context = room_cat.get(
                 "event_context", None
             )
 
+            group_by = room_cat.get("groupings", {}).get("group_by", {})
+            group_keys = [g["key"] for g in group_by]
+
             if event_context is not None:
                 before_limit = int(event_context.get(
                     "before_limit", 5
@@ -65,6 +70,15 @@ class SearchHandler(BaseHandler):
         except KeyError:
             raise SynapseError(400, "Invalid search query")
 
+        if order_by not in ("rank", "recent"):
+            raise SynapseError(400, "Invalid order by: %r" % (order_by,))
+
+        if set(group_keys) - {"room_id", "sender"}:
+            raise SynapseError(
+                400,
+                "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+            )
+
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
@@ -77,18 +91,88 @@ class SearchHandler(BaseHandler):
 
         room_ids = search_filter.filter_rooms(room_ids)
 
-        rank_map, event_map, _ = yield self.store.search_msgs(
-            room_ids, search_term, keys
-        )
+        rank_map = {}
+        allowed_events = []
+        room_groups = {}
+        sender_group = {}
 
-        filtered_events = search_filter.filter(event_map.values())
+        if order_by == "rank":
+            rank_map, event_map, _ = yield self.store.search_msgs(
+                room_ids, search_term, keys
+            )
 
-        allowed_events = yield self._filter_events_for_client(
-            user.to_string(), filtered_events
-        )
+            filtered_events = search_filter.filter(event_map.values())
 
-        allowed_events.sort(key=lambda e: -rank_map[e.event_id])
-        allowed_events = allowed_events[:search_filter.limit()]
+            events = yield self._filter_events_for_client(
+                user.to_string(), filtered_events
+            )
+
+            events.sort(key=lambda e: -rank_map[e.event_id])
+            allowed_events = events[:search_filter.limit()]
+
+            for e in allowed_events:
+                rm = room_groups.setdefault(e.room_id, {
+                    "results": [],
+                    "order": rank_map[e.event_id],
+                })
+                rm["results"].append(e.event_id)
+
+                s = sender_group.setdefault(e.sender, {
+                    "results": [],
+                    "order": rank_map[e.event_id],
+                })
+                s["results"].append(e.event_id)
+
+        elif order_by == "recent":
+            for room_id in room_ids:
+                room_events = []
+                pagination_token = None
+                i = 0
+
+                while len(room_events) < search_filter.limit() and i < 5:
+                    i += 5
+                    r_map, event_map, pagination_token = yield self.store.search_room(
+                        room_id, search_term, keys, search_filter.limit() * 2,
+                        pagination_token=pagination_token,
+                    )
+                    rank_map.update(r_map)
+
+                    filtered_events = search_filter.filter(event_map.values())
+
+                    events = yield self._filter_events_for_client(
+                        user.to_string(), filtered_events
+                    )
+
+                    room_events.extend(events)
+                    room_events = room_events[:search_filter.limit()]
+
+                    if len(event_map) < search_filter.limit() * 2:
+                        break
+
+                if room_events:
+                    group = room_groups.setdefault(room_id, {})
+                    if pagination_token:
+                        group["next_batch"] = pagination_token
+
+                    group["results"] = [e.event_id for e in room_events]
+                    group["order"] = max(
+                        e.origin_server_ts/1000 for e in room_events
+                        if hasattr(e, "origin_server_ts")
+                    )
+
+                allowed_events.extend(room_events)
+
+            # Normalize the group ranks
+            if room_groups:
+                mx = max(g["order"] for g in room_groups.values())
+                mn = min(g["order"] for g in room_groups.values())
+
+                for g in room_groups.values():
+                    g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+
+        else:
+            # We should never get here due to the guard earlier.
+            raise NotImplementedError()
 
         if event_context is not None:
             now_token = yield self.hs.get_event_sources().get_current_token()
@@ -144,11 +228,19 @@ class SearchHandler(BaseHandler):
 
         logger.info("Found %d results", len(results))
 
+        rooms_cat_res = {
+            "results": results,
+            "count": len(results)
+        }
+
+        if room_groups and "room_id" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+
+        if sender_group and "sender" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+
         defer.returnValue({
             "search_categories": {
-                "room_events": {
-                    "results": results,
-                    "count": len(results)
-                }
+                "room_events": rooms_cat_res
             }
         })
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index cdf003502f..e37e56c1f2 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -20,6 +20,12 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 from collections import namedtuple
 
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
 """The result of a search.
 
 Fields:
@@ -109,3 +115,93 @@ class SearchStore(SQLBaseStore):
             event_map,
             None
         ))
+
+    @defer.inlineCallbacks
+    def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+        """Performs a full text search over events with given keys.
+
+        Args:
+            room_id (str): The room_id to search in
+            search_term (str): Search term to search for
+            keys (list): List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            pagination_token (str): A pagination token previously returned
+
+        Returns:
+            SearchResult
+        """
+        clauses = []
+        args = [search_term, room_id]
+
+        local_clauses = []
+        for key in keys:
+            local_clauses.append("key = ?")
+            args.append(key)
+
+        clauses.append(
+            "(%s)" % (" OR ".join(local_clauses),)
+        )
+
+        if pagination_token:
+            topo, stream = pagination_token.split(",")
+            clauses.append(
+                "(topological_ordering < ?"
+                " OR (topological_ordering = ? AND stream_ordering < ?))"
+            )
+            args.extend([topo, topo, stream])
+
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "SELECT ts_rank_cd(vector, query) as rank,"
+                " topological_ordering, stream_ordering, room_id, event_id"
+                " FROM plainto_tsquery('english', ?) as query, event_search"
+                " NATURAL JOIN events"
+                " WHERE vector @@ query AND room_id = ?"
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
+                " topological_ordering, stream_ordering"
+                " FROM event_search"
+                " NATURAL JOIN events"
+                " WHERE value MATCH ? AND room_id = ?"
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        for clause in clauses:
+            sql += " AND " + clause
+
+        # We add an arbitrary limit here to ensure we don't try to pull the
+        # entire table from the database.
+        sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+
+        args.append(limit)
+
+        results = yield self._execute(
+            "search_rooms", self.cursor_to_dict, sql, *args
+        )
+
+        events = yield self._get_events([r["event_id"] for r in results])
+
+        event_map = {
+            ev.event_id: ev
+            for ev in events
+        }
+
+        pagination_token = None
+        if results:
+            topo = results[-1]["topological_ordering"]
+            stream = results[-1]["stream_ordering"]
+            pagination_token = "%s,%s" % (topo, stream)
+
+        defer.returnValue(SearchResult(
+            {
+                r["event_id"]: r["rank"]
+                for r in results
+                if r["event_id"] in event_map
+            },
+            event_map,
+            pagination_token
+        ))