diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd5e90bacb..4d40d3ac9c 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,10 +18,8 @@ import logging
from unpaddedbase64 import decode_base64, encode_base64
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -35,9 +33,11 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+ self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def get_old_rooms_from_upgraded_room(self, room_id):
+ async def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -51,28 +51,42 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = await self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+ # Don't add it to the list until we have checked that we are in the room
+ try:
+ next_predecessor_room = await self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
return historical_room_ids
- @defer.inlineCallbacks
- def search(self, user, content, batch=None):
+ async def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
@@ -161,12 +175,12 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
- room_ids = set(r.room_id for r in rooms)
+ room_ids = {r.room_id for r in rooms}
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
@@ -174,7 +188,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
- ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+ ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids
# Prevent any historical events from being filtered
@@ -205,7 +219,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
- search_result = yield self.store.search_msgs(room_ids, search_term, keys)
+ search_result = await self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -220,8 +234,8 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
- events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ events = await filter_events_for_client(
+ self.storage, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -249,7 +263,7 @@ class SearchHandler(BaseHandler):
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
- search_result = yield self.store.search_rooms(
+ search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
@@ -270,8 +284,8 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
- events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ events = await filter_events_for_client(
+ self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
@@ -325,11 +339,11 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
- res = yield self.store.get_events_around(
+ res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
@@ -339,12 +353,12 @@ class SearchHandler(BaseHandler):
len(res["events_after"]),
)
- res["events_before"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_before"]
+ res["events_before"] = await filter_events_for_client(
+ self.storage, user.to_string(), res["events_before"]
)
- res["events_after"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_after"]
+ res["events_after"] = await filter_events_for_client(
+ self.storage, user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
@@ -356,12 +370,12 @@ class SearchHandler(BaseHandler):
).to_string()
if include_profile:
- senders = set(
+ senders = {
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
- )
+ }
if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
@@ -372,7 +386,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
- state = yield self.store.get_state_for_event(
+ state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -394,22 +408,18 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = (
- yield self._event_serializer.serialize_events(
- context["events_before"], time_now
- )
+ context["events_before"] = await self._event_serializer.serialize_events(
+ context["events_before"], time_now
)
- context["events_after"] = (
- yield self._event_serializer.serialize_events(
- context["events_after"], time_now
- )
+ context["events_after"] = await self._event_serializer.serialize_events(
+ context["events_after"], time_now
)
state_results = {}
if include_state:
- rooms = set(e.room_id for e in allowed_events)
+ rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
- state = yield self.state_handler.get_current_state(room_id)
+ state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
state_results.values()
@@ -423,7 +433,7 @@ class SearchHandler(BaseHandler):
{
"rank": rank_map[e.event_id],
"result": (
- yield self._event_serializer.serialize_event(e, time_now)
+ await self._event_serializer.serialize_event(e, time_now)
),
"context": contexts.get(e.event_id, {}),
}
@@ -438,7 +448,7 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
for room_id, state in state_results.items():
- s[room_id] = yield self._event_serializer.serialize_events(
+ s[room_id] = await self._event_serializer.serialize_events(
state, time_now
)
|