summary refs log tree commit diff
path: root/synapse/handlers/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/search.py')
-rw-r--r--synapse/handlers/search.py100
1 files changed, 55 insertions, 45 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd5e90bacb..4d40d3ac9c 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,10 +18,8 @@ import logging
 
 from unpaddedbase64 import decode_base64, encode_base64
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.storage.state import StateFilter
 from synapse.visibility import filter_events_for_client
@@ -35,9 +33,11 @@ class SearchHandler(BaseHandler):
     def __init__(self, hs):
         super(SearchHandler, self).__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
+        self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def get_old_rooms_from_upgraded_room(self, room_id):
+    async def get_old_rooms_from_upgraded_room(self, room_id):
         """Retrieves room IDs of old rooms in the history of an upgraded room.
 
         We do so by checking the m.room.create event of the room for a
@@ -51,28 +51,42 @@ class SearchHandler(BaseHandler):
             room_id (str): id of the room to search through.
 
         Returns:
-            Deferred[iterable[unicode]]: predecessor room ids
+            Deferred[iterable[str]]: predecessor room ids
         """
 
         historical_room_ids = []
 
-        while True:
-            predecessor = yield self.store.get_room_predecessor(room_id)
+        # The initial room must have been known for us to get this far
+        predecessor = await self.store.get_room_predecessor(room_id)
 
-            # If no predecessor, assume we've hit a dead end
+        while True:
             if not predecessor:
+                # We have reached the end of the chain of predecessors
+                break
+
+            if not isinstance(predecessor.get("room_id"), str):
+                # This predecessor object is malformed. Exit here
+                break
+
+            predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that we are in the room
+            try:
+                next_predecessor_room = await self.store.get_room_predecessor(
+                    predecessor_room_id
+                )
+            except NotFoundError:
+                # The predecessor is not a known room, so we are done here
                 break
 
-            # Add predecessor's room ID
-            historical_room_ids.append(predecessor["room_id"])
+            historical_room_ids.append(predecessor_room_id)
 
-            # Scan through the old room for further predecessors
-            room_id = predecessor["room_id"]
+            # And repeat
+            predecessor = next_predecessor_room
 
         return historical_room_ids
 
-    @defer.inlineCallbacks
-    def search(self, user, content, batch=None):
+    async def search(self, user, content, batch=None):
         """Performs a full text search for a user.
 
         Args:
@@ -161,12 +175,12 @@ class SearchHandler(BaseHandler):
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
-        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
             user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
         )
-        room_ids = set(r.room_id for r in rooms)
+        room_ids = {r.room_id for r in rooms}
 
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
@@ -174,7 +188,7 @@ class SearchHandler(BaseHandler):
             historical_room_ids = []
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
-                ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+                ids = await self.get_old_rooms_from_upgraded_room(room_id)
                 historical_room_ids += ids
 
             # Prevent any historical events from being filtered
@@ -205,7 +219,7 @@ class SearchHandler(BaseHandler):
         count = None
 
         if order_by == "rank":
-            search_result = yield self.store.search_msgs(room_ids, search_term, keys)
+            search_result = await self.store.search_msgs(room_ids, search_term, keys)
 
             count = search_result["count"]
 
@@ -220,8 +234,8 @@ class SearchHandler(BaseHandler):
 
             filtered_events = search_filter.filter([r["event"] for r in results])
 
-            events = yield filter_events_for_client(
-                self.store, user.to_string(), filtered_events
+            events = await filter_events_for_client(
+                self.storage, user.to_string(), filtered_events
             )
 
             events.sort(key=lambda e: -rank_map[e.event_id])
@@ -249,7 +263,7 @@ class SearchHandler(BaseHandler):
             # But only go around 5 times since otherwise synapse will be sad.
             while len(room_events) < search_filter.limit() and i < 5:
                 i += 1
-                search_result = yield self.store.search_rooms(
+                search_result = await self.store.search_rooms(
                     room_ids,
                     search_term,
                     keys,
@@ -270,8 +284,8 @@ class SearchHandler(BaseHandler):
 
                 filtered_events = search_filter.filter([r["event"] for r in results])
 
-                events = yield filter_events_for_client(
-                    self.store, user.to_string(), filtered_events
+                events = await filter_events_for_client(
+                    self.storage, user.to_string(), filtered_events
                 )
 
                 room_events.extend(events)
@@ -325,11 +339,11 @@ class SearchHandler(BaseHandler):
         # If client has asked for "context" for each event (i.e. some surrounding
         # events and state), fetch that
         if event_context is not None:
-            now_token = yield self.hs.get_event_sources().get_current_token()
+            now_token = await self.hs.get_event_sources().get_current_token()
 
             contexts = {}
             for event in allowed_events:
-                res = yield self.store.get_events_around(
+                res = await self.store.get_events_around(
                     event.room_id, event.event_id, before_limit, after_limit
                 )
 
@@ -339,12 +353,12 @@ class SearchHandler(BaseHandler):
                     len(res["events_after"]),
                 )
 
-                res["events_before"] = yield filter_events_for_client(
-                    self.store, user.to_string(), res["events_before"]
+                res["events_before"] = await filter_events_for_client(
+                    self.storage, user.to_string(), res["events_before"]
                 )
 
-                res["events_after"] = yield filter_events_for_client(
-                    self.store, user.to_string(), res["events_after"]
+                res["events_after"] = await filter_events_for_client(
+                    self.storage, user.to_string(), res["events_after"]
                 )
 
                 res["start"] = now_token.copy_and_replace(
@@ -356,12 +370,12 @@ class SearchHandler(BaseHandler):
                 ).to_string()
 
                 if include_profile:
-                    senders = set(
+                    senders = {
                         ev.sender
                         for ev in itertools.chain(
                             res["events_before"], [event], res["events_after"]
                         )
-                    )
+                    }
 
                     if res["events_after"]:
                         last_event_id = res["events_after"][-1].event_id
@@ -372,7 +386,7 @@ class SearchHandler(BaseHandler):
                         [(EventTypes.Member, sender) for sender in senders]
                     )
 
-                    state = yield self.store.get_state_for_event(
+                    state = await self.state_store.get_state_for_event(
                         last_event_id, state_filter
                     )
 
@@ -394,22 +408,18 @@ class SearchHandler(BaseHandler):
         time_now = self.clock.time_msec()
 
         for context in contexts.values():
-            context["events_before"] = (
-                yield self._event_serializer.serialize_events(
-                    context["events_before"], time_now
-                )
+            context["events_before"] = await self._event_serializer.serialize_events(
+                context["events_before"], time_now
             )
-            context["events_after"] = (
-                yield self._event_serializer.serialize_events(
-                    context["events_after"], time_now
-                )
+            context["events_after"] = await self._event_serializer.serialize_events(
+                context["events_after"], time_now
             )
 
         state_results = {}
         if include_state:
-            rooms = set(e.room_id for e in allowed_events)
+            rooms = {e.room_id for e in allowed_events}
             for room_id in rooms:
-                state = yield self.state_handler.get_current_state(room_id)
+                state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
             state_results.values()
@@ -423,7 +433,7 @@ class SearchHandler(BaseHandler):
                 {
                     "rank": rank_map[e.event_id],
                     "result": (
-                        yield self._event_serializer.serialize_event(e, time_now)
+                        await self._event_serializer.serialize_event(e, time_now)
                     ),
                     "context": contexts.get(e.event_id, {}),
                 }
@@ -438,7 +448,7 @@ class SearchHandler(BaseHandler):
         if state_results:
             s = {}
             for room_id, state in state_results.items():
-                s[room_id] = yield self._event_serializer.serialize_events(
+                s[room_id] = await self._event_serializer.serialize_events(
                     state, time_now
                 )