summary refs log tree commit diff
path: root/synapse/handlers/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/search.py')
-rw-r--r--synapse/handlers/search.py319
1 files changed, 319 insertions, 0 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
new file mode 100644
index 0000000000..b7545c111f
--- /dev/null
+++ b/synapse/handlers/search.py
@@ -0,0 +1,319 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+from synapse.api.constants import Membership
+from synapse.api.filtering import Filter
+from synapse.api.errors import SynapseError
+from synapse.events.utils import serialize_event
+
+from unpaddedbase64 import decode_base64, encode_base64
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class SearchHandler(BaseHandler):
+
+    def __init__(self, hs):
+        super(SearchHandler, self).__init__(hs)
+
+    @defer.inlineCallbacks
+    def search(self, user, content, batch=None):
+        """Performs a full text search for a user.
+
+        Args:
+            user (UserID)
+            content (dict): Search parameters
+            batch (str): The next_batch parameter. Used for pagination.
+
+        Returns:
+            dict to be returned to the client with results of search
+        """
+
+        batch_group = None
+        batch_group_key = None
+        batch_token = None
+        if batch:
+            try:
+                b = decode_base64(batch)
+                batch_group, batch_group_key, batch_token = b.split("\n")
+
+                assert batch_group is not None
+                assert batch_group_key is not None
+                assert batch_token is not None
+            except:
+                raise SynapseError(400, "Invalid batch")
+
+        try:
+            room_cat = content["search_categories"]["room_events"]
+
+            # The actual thing to query in FTS
+            search_term = room_cat["search_term"]
+
+            # Which "keys" to search over in FTS query
+            keys = room_cat.get("keys", [
+                "content.body", "content.name", "content.topic",
+            ])
+
+            # Filter to apply to results
+            filter_dict = room_cat.get("filter", {})
+
+            # What to order results by (impacts whether pagination can be doen)
+            order_by = room_cat.get("order_by", "rank")
+
+            # Include context around each event?
+            event_context = room_cat.get(
+                "event_context", None
+            )
+
+            # Group results together? May allow clients to paginate within a
+            # group
+            group_by = room_cat.get("groupings", {}).get("group_by", {})
+            group_keys = [g["key"] for g in group_by]
+
+            if event_context is not None:
+                before_limit = int(event_context.get(
+                    "before_limit", 5
+                ))
+                after_limit = int(event_context.get(
+                    "after_limit", 5
+                ))
+        except KeyError:
+            raise SynapseError(400, "Invalid search query")
+
+        if order_by not in ("rank", "recent"):
+            raise SynapseError(400, "Invalid order by: %r" % (order_by,))
+
+        if set(group_keys) - {"room_id", "sender"}:
+            raise SynapseError(
+                400,
+                "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+            )
+
+        search_filter = Filter(filter_dict)
+
+        # TODO: Search through left rooms too
+        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+            user.to_string(),
+            membership_list=[Membership.JOIN],
+            # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
+        )
+        room_ids = set(r.room_id for r in rooms)
+
+        room_ids = search_filter.filter_rooms(room_ids)
+
+        if batch_group == "room_id":
+            room_ids.intersection_update({batch_group_key})
+
+        rank_map = {}  # event_id -> rank of event
+        allowed_events = []
+        room_groups = {}  # Holds result of grouping by room, if applicable
+        sender_group = {}  # Holds result of grouping by sender, if applicable
+
+        # Holds the next_batch for the entire result set if one of those exists
+        global_next_batch = None
+
+        if order_by == "rank":
+            results = yield self.store.search_msgs(
+                room_ids, search_term, keys
+            )
+
+            results_map = {r["event"].event_id: r for r in results}
+
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = search_filter.filter([r["event"] for r in results])
+
+            events = yield self._filter_events_for_client(
+                user.to_string(), filtered_events
+            )
+
+            events.sort(key=lambda e: -rank_map[e.event_id])
+            allowed_events = events[:search_filter.limit()]
+
+            for e in allowed_events:
+                rm = room_groups.setdefault(e.room_id, {
+                    "results": [],
+                    "order": rank_map[e.event_id],
+                })
+                rm["results"].append(e.event_id)
+
+                s = sender_group.setdefault(e.sender, {
+                    "results": [],
+                    "order": rank_map[e.event_id],
+                })
+                s["results"].append(e.event_id)
+
+        elif order_by == "recent":
+            # In this case we specifically loop through each room as the given
+            # limit applies to each room, rather than a global list.
+            # This is not necessarilly a good idea.
+            for room_id in room_ids:
+                room_events = []
+                if batch_group == "room_id" and batch_group_key == room_id:
+                    pagination_token = batch_token
+                else:
+                    pagination_token = None
+                i = 0
+
+                # We keep looping and we keep filtering until we reach the limit
+                # or we run out of things.
+                # But only go around 5 times since otherwise synapse will be sad.
+                while len(room_events) < search_filter.limit() and i < 5:
+                    i += 1
+                    results = yield self.store.search_room(
+                        room_id, search_term, keys, search_filter.limit() * 2,
+                        pagination_token=pagination_token,
+                    )
+
+                    results_map = {r["event"].event_id: r for r in results}
+
+                    rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+                    filtered_events = search_filter.filter([
+                        r["event"] for r in results
+                    ])
+
+                    events = yield self._filter_events_for_client(
+                        user.to_string(), filtered_events
+                    )
+
+                    room_events.extend(events)
+                    room_events = room_events[:search_filter.limit()]
+
+                    if len(results) < search_filter.limit() * 2:
+                        pagination_token = None
+                        break
+                    else:
+                        pagination_token = results[-1]["pagination_token"]
+
+                if room_events:
+                    res = results_map[room_events[-1].event_id]
+                    pagination_token = res["pagination_token"]
+
+                    group = room_groups.setdefault(room_id, {})
+                    if pagination_token:
+                        next_batch = encode_base64("%s\n%s\n%s" % (
+                            "room_id", room_id, pagination_token
+                        ))
+                        group["next_batch"] = next_batch
+
+                        if batch_token:
+                            global_next_batch = next_batch
+
+                    group["results"] = [e.event_id for e in room_events]
+                    group["order"] = max(
+                        e.origin_server_ts/1000 for e in room_events
+                        if hasattr(e, "origin_server_ts")
+                    )
+
+                allowed_events.extend(room_events)
+
+            # Normalize the group orders
+            if room_groups:
+                if len(room_groups) > 1:
+                    mx = max(g["order"] for g in room_groups.values())
+                    mn = min(g["order"] for g in room_groups.values())
+
+                    for g in room_groups.values():
+                        g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
+                else:
+                    room_groups.values()[0]["order"] = 1
+
+        else:
+            # We should never get here due to the guard earlier.
+            raise NotImplementedError()
+
+        # If client has asked for "context" for each event (i.e. some surrounding
+        # events and state), fetch that
+        if event_context is not None:
+            now_token = yield self.hs.get_event_sources().get_current_token()
+
+            contexts = {}
+            for event in allowed_events:
+                res = yield self.store.get_events_around(
+                    event.room_id, event.event_id, before_limit, after_limit
+                )
+
+                res["events_before"] = yield self._filter_events_for_client(
+                    user.to_string(), res["events_before"]
+                )
+
+                res["events_after"] = yield self._filter_events_for_client(
+                    user.to_string(), res["events_after"]
+                )
+
+                res["start"] = now_token.copy_and_replace(
+                    "room_key", res["start"]
+                ).to_string()
+
+                res["end"] = now_token.copy_and_replace(
+                    "room_key", res["end"]
+                ).to_string()
+
+                contexts[event.event_id] = res
+        else:
+            contexts = {}
+
+        # TODO: Add a limit
+
+        time_now = self.clock.time_msec()
+
+        for context in contexts.values():
+            context["events_before"] = [
+                serialize_event(e, time_now)
+                for e in context["events_before"]
+            ]
+            context["events_after"] = [
+                serialize_event(e, time_now)
+                for e in context["events_after"]
+            ]
+
+        results = {
+            e.event_id: {
+                "rank": rank_map[e.event_id],
+                "result": serialize_event(e, time_now),
+                "context": contexts.get(e.event_id, {}),
+            }
+            for e in allowed_events
+        }
+
+        logger.info("Found %d results", len(results))
+
+        rooms_cat_res = {
+            "results": results,
+            "count": len(results)
+        }
+
+        if room_groups and "room_id" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+
+        if sender_group and "sender" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
+
+        defer.returnValue({
+            "search_categories": {
+                "room_events": rooms_cat_res
+            }
+        })