diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 41cb809078..afd14da112 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,8 +14,9 @@
import itertools
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+import attr
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
@@ -32,6 +33,20 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SearchResult:
+ # The count of results.
+ count: int
+ # A mapping of event ID to the rank of that event.
+ rank_map: Dict[str, int]
+ # A list of the resulting events.
+ allowed_events: List[EventBase]
+ # A map of room ID to results.
+ room_groups: Dict[str, JsonDict]
+ # A set of event IDs to highlight.
+ highlights: Set[str]
+
+
class SearchHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
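For readers unfamiliar with the `attr.s(slots=True, frozen=True, auto_attribs=True)` decorator used for `_SearchResult` above: it generates a slotted, immutable value class from the type annotations. A minimal sketch with a toy class (the class and values here are illustrative, not part of this change):

```python
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class _Toy:
    count: int
    highlights: set

result = _Toy(count=2, highlights={"matrix"})
print(result.count)  # 2

# frozen=True makes instances read-only after construction.
try:
    result.count = 3
except attr.exceptions.FrozenInstanceError:
    print("immutable, as expected")
```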
@@ -100,7 +115,7 @@ class SearchHandler:
"""Performs a full text search for a user.
Args:
- user
+ user: The user performing the search.
content: Search parameters
batch: The next_batch parameter. Used for pagination.
@@ -156,6 +171,8 @@ class SearchHandler:
# Include context around each event?
event_context = room_cat.get("event_context", None)
+ before_limit = after_limit = None
+ include_profile = False
# Group results together? May allow clients to paginate within a
# group
@@ -182,6 +199,73 @@ class SearchHandler:
% (set(group_keys) - {"room_id", "sender"},),
)
+ return await self._search(
+ user,
+ batch_group,
+ batch_group_key,
+ batch_token,
+ search_term,
+ keys,
+ filter_dict,
+ order_by,
+ include_state,
+ group_keys,
+ event_context,
+ before_limit,
+ after_limit,
+ include_profile,
+ )
+
+ async def _search(
+ self,
+ user: UserID,
+ batch_group: Optional[str],
+ batch_group_key: Optional[str],
+ batch_token: Optional[str],
+ search_term: str,
+ keys: List[str],
+ filter_dict: JsonDict,
+ order_by: str,
+ include_state: bool,
+ group_keys: List[str],
+ event_context: Optional[bool],
+ before_limit: Optional[int],
+ after_limit: Optional[int],
+ include_profile: bool,
+ ) -> JsonDict:
+ """Performs a full text search for a user.
+
+ Args:
+ user: The user performing the search.
+            batch_group: The group being continued, if this is a pagination
+                request: "room_id", "sender", or "all".
+            batch_group_key: The key of the group being continued, e.g. a
+                room ID or sender (empty for the "all" group).
+            batch_token: The pagination token returned with the previous
+                batch of results.
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
+ "content.body", "content.name", "content.topic"
+ filter_dict: The JSON to build a filter out of.
+            order_by: How to order the results. Valid values are "rank" and "recent".
+ include_state: True if the state of the room at each result should
+ be included.
+ group_keys: A list of ways to group the results. Valid values are
+ "room_id" and "sender".
+ event_context: True to include contextual events around results.
+ before_limit:
+ The number of events before a result to include as context.
+
+ Only used if event_context is True.
+ after_limit:
+ The number of events after a result to include as context.
+
+ Only used if event_context is True.
+ include_profile: True if historical profile information should be
+ included in the event context.
+
+ Only used if event_context is True.
+
+ Returns:
+            A dict to be returned to the client with the results of the search.
+ """
search_filter = Filter(self.hs, filter_dict)
# TODO: Search through left rooms too
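For reference, `filter_dict` here is a standard Matrix event-filter JSON object. A minimal sketch of one (fields from the Matrix filtering spec; values purely illustrative):

```python
# A minimal event filter of the kind wrapped by Filter() above.
filter_dict = {
    "limit": 10,                           # cap on returned results
    "types": ["m.room.message"],           # only match message events
    "rooms": ["!searchable:example.org"],  # restrict to one room
}
```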
@@ -216,278 +300,399 @@ class SearchHandler:
}
}
- rank_map = {} # event_id -> rank of event
- allowed_events = []
- # Holds result of grouping by room, if applicable
- room_groups: Dict[str, JsonDict] = {}
- # Holds result of grouping by sender, if applicable
- sender_group: Dict[str, JsonDict] = {}
+ sender_group: Optional[Dict[str, JsonDict]]
- # Holds the next_batch for the entire result set if one of those exists
- global_next_batch = None
-
- highlights = set()
+ if order_by == "rank":
+ search_result, sender_group = await self._search_by_rank(
+ user, room_ids, search_term, keys, search_filter
+ )
+            # A rank-ordered search has no next batch token.
+ global_next_batch = None
+ elif order_by == "recent":
+ search_result, global_next_batch = await self._search_by_recent(
+ user,
+ room_ids,
+ search_term,
+ keys,
+ search_filter,
+ batch_group,
+ batch_group_key,
+ batch_token,
+ )
+            # A recency-ordered search does not group results by sender.
+ sender_group = None
+ else:
+ # We should never get here due to the guard earlier.
+ raise NotImplementedError()
- count = None
+ logger.info("Found %d events to return", len(search_result.allowed_events))
- if order_by == "rank":
- search_result = await self.store.search_msgs(room_ids, search_term, keys)
+ # If client has asked for "context" for each event (i.e. some surrounding
+ # events and state), fetch that
+ if event_context is not None:
+ # Note that before and after limit must be set in this case.
+ assert before_limit is not None
+ assert after_limit is not None
+
+ contexts = await self._calculate_event_contexts(
+ user,
+ search_result.allowed_events,
+ before_limit,
+ after_limit,
+ include_profile,
+ )
+ else:
+ contexts = {}
- count = search_result["count"]
+ # TODO: Add a limit
- if search_result["highlights"]:
- highlights.update(search_result["highlights"])
+ state_results = {}
+ if include_state:
+ for room_id in {e.room_id for e in search_result.allowed_events}:
+ state = await self.state_handler.get_current_state(room_id)
+ state_results[room_id] = list(state.values())
- results = search_result["results"]
+ aggregations = None
+ if self._msc3666_enabled:
+ aggregations = await self.store.get_bundled_aggregations(
+ # Generate an iterable of EventBase for all the events that will be
+ # returned, including contextual events.
+ itertools.chain(
+ # The events_before and events_after for each context.
+ itertools.chain.from_iterable(
+ itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
+ for context in contexts.values()
+ ),
+ # The returned events.
+ search_result.allowed_events,
+ ),
+ user.to_string(),
+ )
- rank_map.update({r["event"].event_id: r["rank"] for r in results})
+ # We're now about to serialize the events. We should not make any
+ # blocking calls after this. Otherwise, the 'age' will be wrong.
- filtered_events = await search_filter.filter([r["event"] for r in results])
+ time_now = self.clock.time_msec()
- events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ for context in contexts.values():
+ context["events_before"] = self._event_serializer.serialize_events(
+ context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+ )
+ context["events_after"] = self._event_serializer.serialize_events(
+ context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
)
- events.sort(key=lambda e: -rank_map[e.event_id])
- allowed_events = events[: search_filter.limit]
+ results = [
+ {
+ "rank": search_result.rank_map[e.event_id],
+ "result": self._event_serializer.serialize_event(
+ e, time_now, bundle_aggregations=aggregations
+ ),
+ "context": contexts.get(e.event_id, {}),
+ }
+ for e in search_result.allowed_events
+ ]
- for e in allowed_events:
- rm = room_groups.setdefault(
- e.room_id, {"results": [], "order": rank_map[e.event_id]}
- )
- rm["results"].append(e.event_id)
+ rooms_cat_res: JsonDict = {
+ "results": results,
+ "count": search_result.count,
+ "highlights": list(search_result.highlights),
+ }
- s = sender_group.setdefault(
- e.sender, {"results": [], "order": rank_map[e.event_id]}
- )
- s["results"].append(e.event_id)
+ if state_results:
+ rooms_cat_res["state"] = {
+ room_id: self._event_serializer.serialize_events(state_events, time_now)
+ for room_id, state_events in state_results.items()
+ }
- elif order_by == "recent":
- room_events: List[EventBase] = []
- i = 0
-
- pagination_token = batch_token
-
- # We keep looping and we keep filtering until we reach the limit
- # or we run out of things.
- # But only go around 5 times since otherwise synapse will be sad.
- while len(room_events) < search_filter.limit and i < 5:
- i += 1
- search_result = await self.store.search_rooms(
- room_ids,
- search_term,
- keys,
- search_filter.limit * 2,
- pagination_token=pagination_token,
- )
+ if search_result.room_groups and "room_id" in group_keys:
+ rooms_cat_res.setdefault("groups", {})[
+ "room_id"
+ ] = search_result.room_groups
- if search_result["highlights"]:
- highlights.update(search_result["highlights"])
+ if sender_group and "sender" in group_keys:
+ rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
- count = search_result["count"]
+ if global_next_batch:
+ rooms_cat_res["next_batch"] = global_next_batch
- results = search_result["results"]
+ return {"search_categories": {"room_events": rooms_cat_res}}
- results_map = {r["event"].event_id: r for r in results}
+ async def _search_by_rank(
+ self,
+ user: UserID,
+ room_ids: Collection[str],
+ search_term: str,
+ keys: Iterable[str],
+ search_filter: Filter,
+ ) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
+ """
+        Performs a full text search for a user, ordering the results by rank.
- rank_map.update({r["event"].event_id: r["rank"] for r in results})
+ Args:
+ user: The user performing the search.
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
+ "content.body", "content.name", "content.topic"
+ search_filter: The event filter to use.
- filtered_events = await search_filter.filter(
- [r["event"] for r in results]
- )
+ Returns:
+ A tuple of:
+ The search results.
+ A map of sender ID to results.
+ """
+ # Holds result of grouping by room, if applicable
+ room_groups: Dict[str, JsonDict] = {}
+ # Holds result of grouping by sender, if applicable
+ sender_group: Dict[str, JsonDict] = {}
- events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
- )
+ search_result = await self.store.search_msgs(room_ids, search_term, keys)
- room_events.extend(events)
- room_events = room_events[: search_filter.limit]
+ if search_result["highlights"]:
+ highlights = search_result["highlights"]
+ else:
+ highlights = set()
- if len(results) < search_filter.limit * 2:
- pagination_token = None
- break
- else:
- pagination_token = results[-1]["pagination_token"]
-
- for event in room_events:
- group = room_groups.setdefault(event.room_id, {"results": []})
- group["results"].append(event.event_id)
-
- if room_events and len(room_events) >= search_filter.limit:
- last_event_id = room_events[-1].event_id
- pagination_token = results_map[last_event_id]["pagination_token"]
-
- # We want to respect the given batch group and group keys so
- # that if people blindly use the top level `next_batch` token
- # it returns more from the same group (if applicable) rather
- # than reverting to searching all results again.
- if batch_group and batch_group_key:
- global_next_batch = encode_base64(
- (
- "%s\n%s\n%s"
- % (batch_group, batch_group_key, pagination_token)
- ).encode("ascii")
- )
- else:
- global_next_batch = encode_base64(
- ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
- )
+ results = search_result["results"]
- for room_id, group in room_groups.items():
- group["next_batch"] = encode_base64(
- ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
- "ascii"
- )
- )
+ # event_id -> rank of event
+ rank_map = {r["event"].event_id: r["rank"] for r in results}
- allowed_events.extend(room_events)
+ filtered_events = await search_filter.filter([r["event"] for r in results])
- else:
- # We should never get here due to the guard earlier.
- raise NotImplementedError()
+ events = await filter_events_for_client(
+ self.storage, user.to_string(), filtered_events
+ )
- logger.info("Found %d events to return", len(allowed_events))
+ events.sort(key=lambda e: -rank_map[e.event_id])
+ allowed_events = events[: search_filter.limit]
- # If client has asked for "context" for each event (i.e. some surrounding
- # events and state), fetch that
- if event_context is not None:
- now_token = self.hs.get_event_sources().get_current_token()
+ for e in allowed_events:
+ rm = room_groups.setdefault(
+ e.room_id, {"results": [], "order": rank_map[e.event_id]}
+ )
+ rm["results"].append(e.event_id)
- contexts = {}
- for event in allowed_events:
- res = await self.store.get_events_around(
- event.room_id, event.event_id, before_limit, after_limit
- )
+ s = sender_group.setdefault(
+ e.sender, {"results": [], "order": rank_map[e.event_id]}
+ )
+ s["results"].append(e.event_id)
+
+ return (
+ _SearchResult(
+ search_result["count"],
+ rank_map,
+ allowed_events,
+ room_groups,
+ highlights,
+ ),
+ sender_group,
+ )
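The ordering and grouping logic above is easy to check in isolation: sort descending by rank, truncate to the filter limit, and note that `setdefault` pins each group's `order` to the rank of the first (i.e. highest-ranked) event seen for that group. A sketch with event IDs standing in for events and made-up ranks:

```python
rank_map = {"$a": 0.7, "$b": 0.3, "$c": 0.9}
rooms = {"$a": "!r1", "$b": "!r1", "$c": "!r2"}

events = ["$b", "$a", "$c"]
events.sort(key=lambda e: -rank_map[e])
allowed = events[:2]  # a limit of 2
assert allowed == ["$c", "$a"]

room_groups: dict = {}
for e in allowed:
    # "order" is set once, from the best-ranked event in the group.
    group = room_groups.setdefault(rooms[e], {"results": [], "order": rank_map[e]})
    group["results"].append(e)

assert room_groups["!r1"]["order"] == 0.7
assert room_groups["!r2"]["results"] == ["$c"]
```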
- logger.info(
- "Context for search returned %d and %d events",
- len(res.events_before),
- len(res.events_after),
- )
+ async def _search_by_recent(
+ self,
+ user: UserID,
+ room_ids: Collection[str],
+ search_term: str,
+ keys: Iterable[str],
+ search_filter: Filter,
+ batch_group: Optional[str],
+ batch_group_key: Optional[str],
+ batch_token: Optional[str],
+ ) -> Tuple[_SearchResult, Optional[str]]:
+ """
+        Performs a full text search for a user, ordering the results by recency.
- events_before = await filter_events_for_client(
- self.storage, user.to_string(), res.events_before
- )
+ Args:
+ user: The user performing the search.
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
+ "content.body", "content.name", "content.topic"
+ search_filter: The event filter to use.
+            batch_group: The group being continued, if this is a pagination
+                request: "room_id", "sender", or "all".
+            batch_group_key: The key of the group being continued, e.g. a
+                room ID or sender (empty for the "all" group).
+            batch_token: The pagination token returned with the previous
+                batch of results.
- events_after = await filter_events_for_client(
- self.storage, user.to_string(), res.events_after
- )
+ Returns:
+ A tuple of:
+ The search results.
+ Optionally, a pagination token.
+ """
+ rank_map = {} # event_id -> rank of event
+ # Holds result of grouping by room, if applicable
+ room_groups: Dict[str, JsonDict] = {}
- context = {
- "events_before": events_before,
- "events_after": events_after,
- "start": await now_token.copy_and_replace(
- "room_key", res.start
- ).to_string(self.store),
- "end": await now_token.copy_and_replace(
- "room_key", res.end
- ).to_string(self.store),
- }
+ # Holds the next_batch for the entire result set if one of those exists
+ global_next_batch = None
- if include_profile:
- senders = {
- ev.sender
- for ev in itertools.chain(events_before, [event], events_after)
- }
+ highlights = set()
- if events_after:
- last_event_id = events_after[-1].event_id
- else:
- last_event_id = event.event_id
+        room_events: List[EventBase] = []
+        # Initialise count in case the loop below never runs (e.g. a limit of 0).
+        count = 0
+        i = 0
+
+ pagination_token = batch_token
+
+ # We keep looping and we keep filtering until we reach the limit
+ # or we run out of things.
+ # But only go around 5 times since otherwise synapse will be sad.
+ while len(room_events) < search_filter.limit and i < 5:
+ i += 1
+ search_result = await self.store.search_rooms(
+ room_ids,
+ search_term,
+ keys,
+ search_filter.limit * 2,
+ pagination_token=pagination_token,
+ )
- state_filter = StateFilter.from_types(
- [(EventTypes.Member, sender) for sender in senders]
- )
+ if search_result["highlights"]:
+ highlights.update(search_result["highlights"])
+
+ count = search_result["count"]
+
+ results = search_result["results"]
+
+ results_map = {r["event"].event_id: r for r in results}
+
+ rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+ filtered_events = await search_filter.filter([r["event"] for r in results])
- state = await self.state_store.get_state_for_event(
- last_event_id, state_filter
+ events = await filter_events_for_client(
+ self.storage, user.to_string(), filtered_events
+ )
+
+ room_events.extend(events)
+ room_events = room_events[: search_filter.limit]
+
+ if len(results) < search_filter.limit * 2:
+ break
+ else:
+ pagination_token = results[-1]["pagination_token"]
+
+ for event in room_events:
+ group = room_groups.setdefault(event.room_id, {"results": []})
+ group["results"].append(event.event_id)
+
+ if room_events and len(room_events) >= search_filter.limit:
+ last_event_id = room_events[-1].event_id
+ pagination_token = results_map[last_event_id]["pagination_token"]
+
+ # We want to respect the given batch group and group keys so
+ # that if people blindly use the top level `next_batch` token
+ # it returns more from the same group (if applicable) rather
+ # than reverting to searching all results again.
+ if batch_group and batch_group_key:
+ global_next_batch = encode_base64(
+ (
+ "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
+ ).encode("ascii")
+ )
+ else:
+ global_next_batch = encode_base64(
+ ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+ )
+
+ for room_id, group in room_groups.items():
+ group["next_batch"] = encode_base64(
+ ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+ "ascii"
)
+ )
- context["profile_info"] = {
- s.state_key: {
- "displayname": s.content.get("displayname", None),
- "avatar_url": s.content.get("avatar_url", None),
- }
- for s in state.values()
- if s.type == EventTypes.Member and s.state_key in senders
- }
+ return (
+ _SearchResult(count, rank_map, room_events, room_groups, highlights),
+ global_next_batch,
+ )
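The `next_batch` tokens built above are just an unpadded-base64 encoding of three newline-separated fields. A round-trip sketch (the helper names and sample values are illustrative, not part of the handler):

```python
from unpaddedbase64 import decode_base64, encode_base64

def make_next_batch(group: str, group_key: str, pagination_token: str) -> str:
    # Same newline-separated layout the handler encodes above.
    return encode_base64(
        ("%s\n%s\n%s" % (group, group_key, pagination_token)).encode("ascii")
    )

def parse_next_batch(token: str) -> tuple:
    # The inverse operation, applied when a batch token comes back in.
    return tuple(decode_base64(token).decode("ascii").split("\n"))

token = make_next_batch("room_id", "!room:example.org", "t123-456")
assert parse_next_batch(token) == ("room_id", "!room:example.org", "t123-456")
```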
- contexts[event.event_id] = context
- else:
- contexts = {}
+ async def _calculate_event_contexts(
+ self,
+ user: UserID,
+ allowed_events: List[EventBase],
+ before_limit: int,
+ after_limit: int,
+ include_profile: bool,
+ ) -> Dict[str, JsonDict]:
+ """
+ Calculates the contextual events for any search results.
- # TODO: Add a limit
+ Args:
+ user: The user performing the search.
+ allowed_events: The search results.
+ before_limit:
+ The number of events before a result to include as context.
+ after_limit:
+ The number of events after a result to include as context.
+ include_profile: True if historical profile information should be
+ included in the event context.
- time_now = self.clock.time_msec()
+ Returns:
+ A map of event ID to contextual information.
+ """
+ now_token = self.hs.get_event_sources().get_current_token()
- aggregations = None
- if self._msc3666_enabled:
- aggregations = await self.store.get_bundled_aggregations(
- # Generate an iterable of EventBase for all the events that will be
- # returned, including contextual events.
- itertools.chain(
- # The events_before and events_after for each context.
- itertools.chain.from_iterable(
- itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
- for context in contexts.values()
- ),
- # The returned events.
- allowed_events,
- ),
- user.to_string(),
+ contexts = {}
+ for event in allowed_events:
+ res = await self.store.get_events_around(
+ event.room_id, event.event_id, before_limit, after_limit
)
- for context in contexts.values():
- context["events_before"] = self._event_serializer.serialize_events(
- context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+ logger.info(
+ "Context for search returned %d and %d events",
+ len(res.events_before),
+ len(res.events_after),
)
- context["events_after"] = self._event_serializer.serialize_events(
- context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
+
+ events_before = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_before
)
- state_results = {}
- if include_state:
- for room_id in {e.room_id for e in allowed_events}:
- state = await self.state_handler.get_current_state(room_id)
- state_results[room_id] = list(state.values())
+ events_after = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_after
+ )
- # We're now about to serialize the events. We should not make any
- # blocking calls after this. Otherwise the 'age' will be wrong
+ context = {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": await now_token.copy_and_replace(
+ "room_key", res.start
+ ).to_string(self.store),
+ "end": await now_token.copy_and_replace("room_key", res.end).to_string(
+ self.store
+ ),
+ }
- results = []
- for e in allowed_events:
- results.append(
- {
- "rank": rank_map[e.event_id],
- "result": self._event_serializer.serialize_event(
- e, time_now, bundle_aggregations=aggregations
- ),
- "context": contexts.get(e.event_id, {}),
+ if include_profile:
+ senders = {
+ ev.sender
+ for ev in itertools.chain(events_before, [event], events_after)
}
- )
- rooms_cat_res = {
- "results": results,
- "count": count,
- "highlights": list(highlights),
- }
+ if events_after:
+ last_event_id = events_after[-1].event_id
+ else:
+ last_event_id = event.event_id
- if state_results:
- s = {}
- for room_id, state_events in state_results.items():
- s[room_id] = self._event_serializer.serialize_events(
- state_events, time_now
+ state_filter = StateFilter.from_types(
+ [(EventTypes.Member, sender) for sender in senders]
)
- rooms_cat_res["state"] = s
-
- if room_groups and "room_id" in group_keys:
- rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+ state = await self.state_store.get_state_for_event(
+ last_event_id, state_filter
+ )
- if sender_group and "sender" in group_keys:
- rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+ context["profile_info"] = {
+ s.state_key: {
+ "displayname": s.content.get("displayname", None),
+ "avatar_url": s.content.get("avatar_url", None),
+ }
+ for s in state.values()
+ if s.type == EventTypes.Member and s.state_key in senders
+ }
- if global_next_batch:
- rooms_cat_res["next_batch"] = global_next_batch
+ contexts[event.event_id] = context
- return {"search_categories": {"room_events": rooms_cat_res}}
+ return contexts
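Putting the pieces together, the value `_search` returns follows the shape of the Matrix search API response. A hedged sketch of that structure, with purely illustrative values (the optional keys are shown commented out):

```python
response = {
    "search_categories": {
        "room_events": {
            "results": [
                {
                    "rank": 0.927,
                    "result": {"event_id": "$hit1", "type": "m.room.message"},
                    "context": {},  # populated when event_context is requested
                }
            ],
            "count": 1,
            "highlights": ["matrix"],
            # Present only when requested/applicable:
            # "state": {...},
            # "groups": {"room_id": {...}, "sender": {...}},
            # "next_batch": "...",
        }
    }
}
```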
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a5764..acea300ed3 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
- async def search_msgs(self, room_ids, search_term, keys):
+ async def search_msgs(
+ self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+ ) -> JsonDict:
"""Performs a full text search over events with given keys.
Args:
- room_ids (list): List of room ids to search in
- search_term (str): Search term to search for
- keys (list): List of keys to search in, currently supports
+ room_ids: List of room ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
- list of dicts
+            A dictionary of results, with "results", "count" and "highlights" keys
"""
clauses = []
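The handler code above reads three keys out of this return value. A stand-in for the shape `search_msgs` produces, with illustrative values (a real `EventBase` replaces the string placeholder):

```python
search_result = {
    "results": [
        # Each entry pairs an event with its full-text rank; search_rooms
        # results additionally carry a "pagination_token".
        {"event": "<EventBase>", "rank": 0.27},
    ],
    "highlights": {"matrix"},  # terms for the client to highlight
    "count": 1,  # total number of matches
}
```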
@@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
self,
room_ids: Collection[str],
search_term: str,
- keys: List[str],
+ keys: Iterable[str],
-        limit,
+        limit: int,
pagination_token: Optional[str] = None,
- ) -> List[dict]:
+ ) -> JsonDict:
"""Performs a full text search over events with given keys.
Args:
|