summary refs log tree commit diff
path: root/synapse/handlers/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/search.py')
-rw-r--r--synapse/handlers/search.py129
1 files changed, 62 insertions, 67 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bba74d6c9..ddc4430d03 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,6 @@ logger = logging.getLogger(__name__)
 
 
 class SearchHandler(BaseHandler):
-
     def __init__(self, hs):
         super(SearchHandler, self).__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
@@ -93,7 +92,7 @@ class SearchHandler(BaseHandler):
         batch_token = None
         if batch:
             try:
-                b = decode_base64(batch).decode('ascii')
+                b = decode_base64(batch).decode("ascii")
                 batch_group, batch_group_key, batch_token = b.split("\n")
 
                 assert batch_group is not None
@@ -104,7 +103,9 @@ class SearchHandler(BaseHandler):
 
         logger.info(
             "Search batch properties: %r, %r, %r",
-            batch_group, batch_group_key, batch_token,
+            batch_group,
+            batch_group_key,
+            batch_token,
         )
 
         logger.info("Search content: %s", content)
@@ -116,9 +117,9 @@ class SearchHandler(BaseHandler):
             search_term = room_cat["search_term"]
 
             # Which "keys" to search over in FTS query
-            keys = room_cat.get("keys", [
-                "content.body", "content.name", "content.topic",
-            ])
+            keys = room_cat.get(
+                "keys", ["content.body", "content.name", "content.topic"]
+            )
 
             # Filter to apply to results
             filter_dict = room_cat.get("filter", {})
@@ -130,9 +131,7 @@ class SearchHandler(BaseHandler):
             include_state = room_cat.get("include_state", False)
 
             # Include context around each event?
-            event_context = room_cat.get(
-                "event_context", None
-            )
+            event_context = room_cat.get("event_context", None)
 
             # Group results together? May allow clients to paginate within a
             # group
@@ -140,12 +139,8 @@ class SearchHandler(BaseHandler):
             group_keys = [g["key"] for g in group_by]
 
             if event_context is not None:
-                before_limit = int(event_context.get(
-                    "before_limit", 5
-                ))
-                after_limit = int(event_context.get(
-                    "after_limit", 5
-                ))
+                before_limit = int(event_context.get("before_limit", 5))
+                after_limit = int(event_context.get("after_limit", 5))
 
                 # Return the historic display name and avatar for the senders
                 # of the events?
@@ -159,7 +154,8 @@ class SearchHandler(BaseHandler):
         if set(group_keys) - {"room_id", "sender"}:
             raise SynapseError(
                 400,
-                "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+                "Invalid group by keys: %r"
+                % (set(group_keys) - {"room_id", "sender"},),
             )
 
         search_filter = Filter(filter_dict)
@@ -190,15 +186,13 @@ class SearchHandler(BaseHandler):
             room_ids.intersection_update({batch_group_key})
 
         if not room_ids:
-            defer.returnValue({
-                "search_categories": {
-                    "room_events": {
-                        "results": [],
-                        "count": 0,
-                        "highlights": [],
+            defer.returnValue(
+                {
+                    "search_categories": {
+                        "room_events": {"results": [], "count": 0, "highlights": []}
                     }
                 }
-            })
+            )
 
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
@@ -213,9 +207,7 @@ class SearchHandler(BaseHandler):
         count = None
 
         if order_by == "rank":
-            search_result = yield self.store.search_msgs(
-                room_ids, search_term, keys
-            )
+            search_result = yield self.store.search_msgs(room_ids, search_term, keys)
 
             count = search_result["count"]
 
@@ -235,19 +227,17 @@ class SearchHandler(BaseHandler):
             )
 
             events.sort(key=lambda e: -rank_map[e.event_id])
-            allowed_events = events[:search_filter.limit()]
+            allowed_events = events[: search_filter.limit()]
 
             for e in allowed_events:
-                rm = room_groups.setdefault(e.room_id, {
-                    "results": [],
-                    "order": rank_map[e.event_id],
-                })
+                rm = room_groups.setdefault(
+                    e.room_id, {"results": [], "order": rank_map[e.event_id]}
+                )
                 rm["results"].append(e.event_id)
 
-                s = sender_group.setdefault(e.sender, {
-                    "results": [],
-                    "order": rank_map[e.event_id],
-                })
+                s = sender_group.setdefault(
+                    e.sender, {"results": [], "order": rank_map[e.event_id]}
+                )
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
@@ -262,7 +252,10 @@ class SearchHandler(BaseHandler):
             while len(room_events) < search_filter.limit() and i < 5:
                 i += 1
                 search_result = yield self.store.search_rooms(
-                    room_ids, search_term, keys, search_filter.limit() * 2,
+                    room_ids,
+                    search_term,
+                    keys,
+                    search_filter.limit() * 2,
                     pagination_token=pagination_token,
                 )
 
@@ -277,16 +270,14 @@ class SearchHandler(BaseHandler):
 
                 rank_map.update({r["event"].event_id: r["rank"] for r in results})
 
-                filtered_events = search_filter.filter([
-                    r["event"] for r in results
-                ])
+                filtered_events = search_filter.filter([r["event"] for r in results])
 
                 events = yield filter_events_for_client(
                     self.store, user.to_string(), filtered_events
                 )
 
                 room_events.extend(events)
-                room_events = room_events[:search_filter.limit()]
+                room_events = room_events[: search_filter.limit()]
 
                 if len(results) < search_filter.limit() * 2:
                     pagination_token = None
@@ -295,9 +286,7 @@ class SearchHandler(BaseHandler):
                     pagination_token = results[-1]["pagination_token"]
 
             for event in room_events:
-                group = room_groups.setdefault(event.room_id, {
-                    "results": [],
-                })
+                group = room_groups.setdefault(event.room_id, {"results": []})
                 group["results"].append(event.event_id)
 
             if room_events and len(room_events) >= search_filter.limit():
@@ -309,18 +298,23 @@ class SearchHandler(BaseHandler):
                 # it returns more from the same group (if applicable) rather
                 # than reverting to searching all results again.
                 if batch_group and batch_group_key:
-                    global_next_batch = encode_base64(("%s\n%s\n%s" % (
-                        batch_group, batch_group_key, pagination_token
-                    )).encode('ascii'))
+                    global_next_batch = encode_base64(
+                        (
+                            "%s\n%s\n%s"
+                            % (batch_group, batch_group_key, pagination_token)
+                        ).encode("ascii")
+                    )
                 else:
-                    global_next_batch = encode_base64(("%s\n%s\n%s" % (
-                        "all", "", pagination_token
-                    )).encode('ascii'))
+                    global_next_batch = encode_base64(
+                        ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+                    )
 
                 for room_id, group in room_groups.items():
-                    group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
-                        "room_id", room_id, pagination_token
-                    )).encode('ascii'))
+                    group["next_batch"] = encode_base64(
+                        ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+                            "ascii"
+                        )
+                    )
 
             allowed_events.extend(room_events)
 
@@ -338,12 +332,13 @@ class SearchHandler(BaseHandler):
             contexts = {}
             for event in allowed_events:
                 res = yield self.store.get_events_around(
-                    event.room_id, event.event_id, before_limit, after_limit,
+                    event.room_id, event.event_id, before_limit, after_limit
                 )
 
                 logger.info(
                     "Context for search returned %d and %d events",
-                    len(res["events_before"]), len(res["events_after"]),
+                    len(res["events_before"]),
+                    len(res["events_after"]),
                 )
 
                 res["events_before"] = yield filter_events_for_client(
@@ -403,12 +398,12 @@ class SearchHandler(BaseHandler):
         for context in contexts.values():
             context["events_before"] = (
                 yield self._event_serializer.serialize_events(
-                    context["events_before"], time_now,
+                    context["events_before"], time_now
                 )
             )
             context["events_after"] = (
                 yield self._event_serializer.serialize_events(
-                    context["events_after"], time_now,
+                    context["events_after"], time_now
                 )
             )
 
@@ -426,11 +421,15 @@ class SearchHandler(BaseHandler):
 
         results = []
         for e in allowed_events:
-            results.append({
-                "rank": rank_map[e.event_id],
-                "result": (yield self._event_serializer.serialize_event(e, time_now)),
-                "context": contexts.get(e.event_id, {}),
-            })
+            results.append(
+                {
+                    "rank": rank_map[e.event_id],
+                    "result": (
+                        yield self._event_serializer.serialize_event(e, time_now)
+                    ),
+                    "context": contexts.get(e.event_id, {}),
+                }
+            )
 
         rooms_cat_res = {
             "results": results,
@@ -442,7 +441,7 @@ class SearchHandler(BaseHandler):
             s = {}
             for room_id, state in state_results.items():
                 s[room_id] = yield self._event_serializer.serialize_events(
-                    state, time_now,
+                    state, time_now
                 )
 
             rooms_cat_res["state"] = s
@@ -456,8 +455,4 @@ class SearchHandler(BaseHandler):
         if global_next_batch:
             rooms_cat_res["next_batch"] = global_next_batch
 
-        defer.returnValue({
-            "search_categories": {
-                "room_events": rooms_cat_res
-            }
-        })
+        defer.returnValue({"search_categories": {"room_events": rooms_cat_res}})