summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-02-15 08:47:05 -0500
committerGitHub <noreply@github.com>2022-02-15 13:47:05 +0000
commite44f91d678e22936b7e2f0d8bf4890159507533b (patch)
tree33541d791cffea42062d28209695e2107ee6a63f /synapse
parentFix incorrect thread summaries when the latest event is edited. (#11992) (diff)
downloadsynapse-e44f91d678e22936b7e2f0d8bf4890159507533b.tar.xz
Refactor search code to reduce function size. (#11991)
Splits the search code into a few logical functions instead of a single
unreadable function.

There are also a few additional changes for readability.

After refactoring it was clear to see there were some unused and
unnecessary variables, which were simplified.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/search.py643
-rw-r--r--synapse/storage/databases/main/search.py17
2 files changed, 434 insertions, 226 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 41cb809078..afd14da112 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,8 +14,9 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
 
+import attr
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
@@ -32,6 +33,20 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SearchResult:
+    # The count of results.
+    count: int
+    # A mapping of event ID to the rank of that event.
+    rank_map: Dict[str, int]
+    # A list of the resulting events.
+    allowed_events: List[EventBase]
+    # A map of room ID to results.
+    room_groups: Dict[str, JsonDict]
+    # A set of event IDs to highlight.
+    highlights: Set[str]
+
+
 class SearchHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -100,7 +115,7 @@ class SearchHandler:
         """Performs a full text search for a user.
 
         Args:
-            user
+            user: The user performing the search.
             content: Search parameters
             batch: The next_batch parameter. Used for pagination.
 
@@ -156,6 +171,8 @@ class SearchHandler:
 
             # Include context around each event?
             event_context = room_cat.get("event_context", None)
+            before_limit = after_limit = None
+            include_profile = False
 
             # Group results together? May allow clients to paginate within a
             # group
@@ -182,6 +199,73 @@ class SearchHandler:
                 % (set(group_keys) - {"room_id", "sender"},),
             )
 
+        return await self._search(
+            user,
+            batch_group,
+            batch_group_key,
+            batch_token,
+            search_term,
+            keys,
+            filter_dict,
+            order_by,
+            include_state,
+            group_keys,
+            event_context,
+            before_limit,
+            after_limit,
+            include_profile,
+        )
+
+    async def _search(
+        self,
+        user: UserID,
+        batch_group: Optional[str],
+        batch_group_key: Optional[str],
+        batch_token: Optional[str],
+        search_term: str,
+        keys: List[str],
+        filter_dict: JsonDict,
+        order_by: str,
+        include_state: bool,
+        group_keys: List[str],
+        event_context: Optional[bool],
+        before_limit: Optional[int],
+        after_limit: Optional[int],
+        include_profile: bool,
+    ) -> JsonDict:
+        """Performs a full text search for a user.
+
+        Args:
+            user: The user performing the search.
+            batch_group: Pagination information.
+            batch_group_key: Pagination information.
+            batch_token: Pagination information.
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            filter_dict: The JSON to build a filter out of.
+            order_by: How to order the results. Valid values ore "rank" and "recent".
+            include_state: True if the state of the room at each result should
+                be included.
+            group_keys: A list of ways to group the results. Valid values are
+                "room_id" and "sender".
+            event_context: True to include contextual events around results.
+            before_limit:
+                The number of events before a result to include as context.
+
+                Only used if event_context is True.
+            after_limit:
+                The number of events after a result to include as context.
+
+                Only used if event_context is True.
+            include_profile: True if historical profile information should be
+                included in the event context.
+
+                Only used if event_context is True.
+
+        Returns:
+            dict to be returned to the client with results of search
+        """
         search_filter = Filter(self.hs, filter_dict)
 
         # TODO: Search through left rooms too
@@ -216,278 +300,399 @@ class SearchHandler:
                 }
             }
 
-        rank_map = {}  # event_id -> rank of event
-        allowed_events = []
-        # Holds result of grouping by room, if applicable
-        room_groups: Dict[str, JsonDict] = {}
-        # Holds result of grouping by sender, if applicable
-        sender_group: Dict[str, JsonDict] = {}
+        sender_group: Optional[Dict[str, JsonDict]]
 
-        # Holds the next_batch for the entire result set if one of those exists
-        global_next_batch = None
-
-        highlights = set()
+        if order_by == "rank":
+            search_result, sender_group = await self._search_by_rank(
+                user, room_ids, search_term, keys, search_filter
+            )
+            # Unused return values for rank search.
+            global_next_batch = None
+        elif order_by == "recent":
+            search_result, global_next_batch = await self._search_by_recent(
+                user,
+                room_ids,
+                search_term,
+                keys,
+                search_filter,
+                batch_group,
+                batch_group_key,
+                batch_token,
+            )
+            # Unused return values for recent search.
+            sender_group = None
+        else:
+            # We should never get here due to the guard earlier.
+            raise NotImplementedError()
 
-        count = None
+        logger.info("Found %d events to return", len(search_result.allowed_events))
 
-        if order_by == "rank":
-            search_result = await self.store.search_msgs(room_ids, search_term, keys)
+        # If client has asked for "context" for each event (i.e. some surrounding
+        # events and state), fetch that
+        if event_context is not None:
+            # Note that before and after limit must be set in this case.
+            assert before_limit is not None
+            assert after_limit is not None
+
+            contexts = await self._calculate_event_contexts(
+                user,
+                search_result.allowed_events,
+                before_limit,
+                after_limit,
+                include_profile,
+            )
+        else:
+            contexts = {}
 
-            count = search_result["count"]
+        # TODO: Add a limit
 
-            if search_result["highlights"]:
-                highlights.update(search_result["highlights"])
+        state_results = {}
+        if include_state:
+            for room_id in {e.room_id for e in search_result.allowed_events}:
+                state = await self.state_handler.get_current_state(room_id)
+                state_results[room_id] = list(state.values())
 
-            results = search_result["results"]
+        aggregations = None
+        if self._msc3666_enabled:
+            aggregations = await self.store.get_bundled_aggregations(
+                # Generate an iterable of EventBase for all the events that will be
+                # returned, including contextual events.
+                itertools.chain(
+                    # The events_before and events_after for each context.
+                    itertools.chain.from_iterable(
+                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                        for context in contexts.values()
+                    ),
+                    # The returned events.
+                    search_result.allowed_events,
+                ),
+                user.to_string(),
+            )
 
-            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+        # We're now about to serialize the events. We should not make any
+        # blocking calls after this. Otherwise, the 'age' will be wrong.
 
-            filtered_events = await search_filter.filter([r["event"] for r in results])
+        time_now = self.clock.time_msec()
 
-            events = await filter_events_for_client(
-                self.storage, user.to_string(), filtered_events
+        for context in contexts.values():
+            context["events_before"] = self._event_serializer.serialize_events(
+                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+            )
+            context["events_after"] = self._event_serializer.serialize_events(
+                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
 
-            events.sort(key=lambda e: -rank_map[e.event_id])
-            allowed_events = events[: search_filter.limit]
+        results = [
+            {
+                "rank": search_result.rank_map[e.event_id],
+                "result": self._event_serializer.serialize_event(
+                    e, time_now, bundle_aggregations=aggregations
+                ),
+                "context": contexts.get(e.event_id, {}),
+            }
+            for e in search_result.allowed_events
+        ]
 
-            for e in allowed_events:
-                rm = room_groups.setdefault(
-                    e.room_id, {"results": [], "order": rank_map[e.event_id]}
-                )
-                rm["results"].append(e.event_id)
+        rooms_cat_res: JsonDict = {
+            "results": results,
+            "count": search_result.count,
+            "highlights": list(search_result.highlights),
+        }
 
-                s = sender_group.setdefault(
-                    e.sender, {"results": [], "order": rank_map[e.event_id]}
-                )
-                s["results"].append(e.event_id)
+        if state_results:
+            rooms_cat_res["state"] = {
+                room_id: self._event_serializer.serialize_events(state_events, time_now)
+                for room_id, state_events in state_results.items()
+            }
 
-        elif order_by == "recent":
-            room_events: List[EventBase] = []
-            i = 0
-
-            pagination_token = batch_token
-
-            # We keep looping and we keep filtering until we reach the limit
-            # or we run out of things.
-            # But only go around 5 times since otherwise synapse will be sad.
-            while len(room_events) < search_filter.limit and i < 5:
-                i += 1
-                search_result = await self.store.search_rooms(
-                    room_ids,
-                    search_term,
-                    keys,
-                    search_filter.limit * 2,
-                    pagination_token=pagination_token,
-                )
+        if search_result.room_groups and "room_id" in group_keys:
+            rooms_cat_res.setdefault("groups", {})[
+                "room_id"
+            ] = search_result.room_groups
 
-                if search_result["highlights"]:
-                    highlights.update(search_result["highlights"])
+        if sender_group and "sender" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
 
-                count = search_result["count"]
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
 
-                results = search_result["results"]
+        return {"search_categories": {"room_events": rooms_cat_res}}
 
-                results_map = {r["event"].event_id: r for r in results}
+    async def _search_by_rank(
+        self,
+        user: UserID,
+        room_ids: Collection[str],
+        search_term: str,
+        keys: Iterable[str],
+        search_filter: Filter,
+    ) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
+        """
+        Performs a full text search for a user ordering by rank.
 
-                rank_map.update({r["event"].event_id: r["rank"] for r in results})
+        Args:
+            user: The user performing the search.
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            search_filter: The event filter to use.
 
-                filtered_events = await search_filter.filter(
-                    [r["event"] for r in results]
-                )
+        Returns:
+            A tuple of:
+                The search results.
+                A map of sender ID to results.
+        """
+        rank_map = {}  # event_id -> rank of event
+        # Holds result of grouping by room, if applicable
+        room_groups: Dict[str, JsonDict] = {}
+        # Holds result of grouping by sender, if applicable
+        sender_group: Dict[str, JsonDict] = {}
 
-                events = await filter_events_for_client(
-                    self.storage, user.to_string(), filtered_events
-                )
+        search_result = await self.store.search_msgs(room_ids, search_term, keys)
 
-                room_events.extend(events)
-                room_events = room_events[: search_filter.limit]
+        if search_result["highlights"]:
+            highlights = search_result["highlights"]
+        else:
+            highlights = set()
 
-                if len(results) < search_filter.limit * 2:
-                    pagination_token = None
-                    break
-                else:
-                    pagination_token = results[-1]["pagination_token"]
-
-            for event in room_events:
-                group = room_groups.setdefault(event.room_id, {"results": []})
-                group["results"].append(event.event_id)
-
-            if room_events and len(room_events) >= search_filter.limit:
-                last_event_id = room_events[-1].event_id
-                pagination_token = results_map[last_event_id]["pagination_token"]
-
-                # We want to respect the given batch group and group keys so
-                # that if people blindly use the top level `next_batch` token
-                # it returns more from the same group (if applicable) rather
-                # than reverting to searching all results again.
-                if batch_group and batch_group_key:
-                    global_next_batch = encode_base64(
-                        (
-                            "%s\n%s\n%s"
-                            % (batch_group, batch_group_key, pagination_token)
-                        ).encode("ascii")
-                    )
-                else:
-                    global_next_batch = encode_base64(
-                        ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
-                    )
+        results = search_result["results"]
 
-                for room_id, group in room_groups.items():
-                    group["next_batch"] = encode_base64(
-                        ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
-                            "ascii"
-                        )
-                    )
+        # event_id -> rank of event
+        rank_map = {r["event"].event_id: r["rank"] for r in results}
 
-            allowed_events.extend(room_events)
+        filtered_events = await search_filter.filter([r["event"] for r in results])
 
-        else:
-            # We should never get here due to the guard earlier.
-            raise NotImplementedError()
+        events = await filter_events_for_client(
+            self.storage, user.to_string(), filtered_events
+        )
 
-        logger.info("Found %d events to return", len(allowed_events))
+        events.sort(key=lambda e: -rank_map[e.event_id])
+        allowed_events = events[: search_filter.limit]
 
-        # If client has asked for "context" for each event (i.e. some surrounding
-        # events and state), fetch that
-        if event_context is not None:
-            now_token = self.hs.get_event_sources().get_current_token()
+        for e in allowed_events:
+            rm = room_groups.setdefault(
+                e.room_id, {"results": [], "order": rank_map[e.event_id]}
+            )
+            rm["results"].append(e.event_id)
 
-            contexts = {}
-            for event in allowed_events:
-                res = await self.store.get_events_around(
-                    event.room_id, event.event_id, before_limit, after_limit
-                )
+            s = sender_group.setdefault(
+                e.sender, {"results": [], "order": rank_map[e.event_id]}
+            )
+            s["results"].append(e.event_id)
+
+        return (
+            _SearchResult(
+                search_result["count"],
+                rank_map,
+                allowed_events,
+                room_groups,
+                highlights,
+            ),
+            sender_group,
+        )
 
-                logger.info(
-                    "Context for search returned %d and %d events",
-                    len(res.events_before),
-                    len(res.events_after),
-                )
+    async def _search_by_recent(
+        self,
+        user: UserID,
+        room_ids: Collection[str],
+        search_term: str,
+        keys: Iterable[str],
+        search_filter: Filter,
+        batch_group: Optional[str],
+        batch_group_key: Optional[str],
+        batch_token: Optional[str],
+    ) -> Tuple[_SearchResult, Optional[str]]:
+        """
+        Performs a full text search for a user ordering by recent.
 
-                events_before = await filter_events_for_client(
-                    self.storage, user.to_string(), res.events_before
-                )
+        Args:
+            user: The user performing the search.
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            search_filter: The event filter to use.
+            batch_group: Pagination information.
+            batch_group_key: Pagination information.
+            batch_token: Pagination information.
 
-                events_after = await filter_events_for_client(
-                    self.storage, user.to_string(), res.events_after
-                )
+        Returns:
+            A tuple of:
+                The search results.
+                Optionally, a pagination token.
+        """
+        rank_map = {}  # event_id -> rank of event
+        # Holds result of grouping by room, if applicable
+        room_groups: Dict[str, JsonDict] = {}
 
-                context = {
-                    "events_before": events_before,
-                    "events_after": events_after,
-                    "start": await now_token.copy_and_replace(
-                        "room_key", res.start
-                    ).to_string(self.store),
-                    "end": await now_token.copy_and_replace(
-                        "room_key", res.end
-                    ).to_string(self.store),
-                }
+        # Holds the next_batch for the entire result set if one of those exists
+        global_next_batch = None
 
-                if include_profile:
-                    senders = {
-                        ev.sender
-                        for ev in itertools.chain(events_before, [event], events_after)
-                    }
+        highlights = set()
 
-                    if events_after:
-                        last_event_id = events_after[-1].event_id
-                    else:
-                        last_event_id = event.event_id
+        room_events: List[EventBase] = []
+        i = 0
+
+        pagination_token = batch_token
+
+        # We keep looping and we keep filtering until we reach the limit
+        # or we run out of things.
+        # But only go around 5 times since otherwise synapse will be sad.
+        while len(room_events) < search_filter.limit and i < 5:
+            i += 1
+            search_result = await self.store.search_rooms(
+                room_ids,
+                search_term,
+                keys,
+                search_filter.limit * 2,
+                pagination_token=pagination_token,
+            )
 
-                    state_filter = StateFilter.from_types(
-                        [(EventTypes.Member, sender) for sender in senders]
-                    )
+            if search_result["highlights"]:
+                highlights.update(search_result["highlights"])
+
+            count = search_result["count"]
+
+            results = search_result["results"]
+
+            results_map = {r["event"].event_id: r for r in results}
+
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = await search_filter.filter([r["event"] for r in results])
 
-                    state = await self.state_store.get_state_for_event(
-                        last_event_id, state_filter
+            events = await filter_events_for_client(
+                self.storage, user.to_string(), filtered_events
+            )
+
+            room_events.extend(events)
+            room_events = room_events[: search_filter.limit]
+
+            if len(results) < search_filter.limit * 2:
+                break
+            else:
+                pagination_token = results[-1]["pagination_token"]
+
+        for event in room_events:
+            group = room_groups.setdefault(event.room_id, {"results": []})
+            group["results"].append(event.event_id)
+
+        if room_events and len(room_events) >= search_filter.limit:
+            last_event_id = room_events[-1].event_id
+            pagination_token = results_map[last_event_id]["pagination_token"]
+
+            # We want to respect the given batch group and group keys so
+            # that if people blindly use the top level `next_batch` token
+            # it returns more from the same group (if applicable) rather
+            # than reverting to searching all results again.
+            if batch_group and batch_group_key:
+                global_next_batch = encode_base64(
+                    (
+                        "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
+                    ).encode("ascii")
+                )
+            else:
+                global_next_batch = encode_base64(
+                    ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+                )
+
+            for room_id, group in room_groups.items():
+                group["next_batch"] = encode_base64(
+                    ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+                        "ascii"
                     )
+                )
 
-                    context["profile_info"] = {
-                        s.state_key: {
-                            "displayname": s.content.get("displayname", None),
-                            "avatar_url": s.content.get("avatar_url", None),
-                        }
-                        for s in state.values()
-                        if s.type == EventTypes.Member and s.state_key in senders
-                    }
+        return (
+            _SearchResult(count, rank_map, room_events, room_groups, highlights),
+            global_next_batch,
+        )
 
-                contexts[event.event_id] = context
-        else:
-            contexts = {}
+    async def _calculate_event_contexts(
+        self,
+        user: UserID,
+        allowed_events: List[EventBase],
+        before_limit: int,
+        after_limit: int,
+        include_profile: bool,
+    ) -> Dict[str, JsonDict]:
+        """
+        Calculates the contextual events for any search results.
 
-        # TODO: Add a limit
+        Args:
+            user: The user performing the search.
+            allowed_events: The search results.
+            before_limit:
+                The number of events before a result to include as context.
+            after_limit:
+                The number of events after a result to include as context.
+            include_profile: True if historical profile information should be
+                included in the event context.
 
-        time_now = self.clock.time_msec()
+        Returns:
+            A map of event ID to contextual information.
+        """
+        now_token = self.hs.get_event_sources().get_current_token()
 
-        aggregations = None
-        if self._msc3666_enabled:
-            aggregations = await self.store.get_bundled_aggregations(
-                # Generate an iterable of EventBase for all the events that will be
-                # returned, including contextual events.
-                itertools.chain(
-                    # The events_before and events_after for each context.
-                    itertools.chain.from_iterable(
-                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
-                        for context in contexts.values()
-                    ),
-                    # The returned events.
-                    allowed_events,
-                ),
-                user.to_string(),
+        contexts = {}
+        for event in allowed_events:
+            res = await self.store.get_events_around(
+                event.room_id, event.event_id, before_limit, after_limit
             )
 
-        for context in contexts.values():
-            context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+            logger.info(
+                "Context for search returned %d and %d events",
+                len(res.events_before),
+                len(res.events_after),
             )
-            context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+
+            events_before = await filter_events_for_client(
+                self.storage, user.to_string(), res.events_before
             )
 
-        state_results = {}
-        if include_state:
-            for room_id in {e.room_id for e in allowed_events}:
-                state = await self.state_handler.get_current_state(room_id)
-                state_results[room_id] = list(state.values())
+            events_after = await filter_events_for_client(
+                self.storage, user.to_string(), res.events_after
+            )
 
-        # We're now about to serialize the events. We should not make any
-        # blocking calls after this. Otherwise the 'age' will be wrong
+            context = {
+                "events_before": events_before,
+                "events_after": events_after,
+                "start": await now_token.copy_and_replace(
+                    "room_key", res.start
+                ).to_string(self.store),
+                "end": await now_token.copy_and_replace("room_key", res.end).to_string(
+                    self.store
+                ),
+            }
 
-        results = []
-        for e in allowed_events:
-            results.append(
-                {
-                    "rank": rank_map[e.event_id],
-                    "result": self._event_serializer.serialize_event(
-                        e, time_now, bundle_aggregations=aggregations
-                    ),
-                    "context": contexts.get(e.event_id, {}),
+            if include_profile:
+                senders = {
+                    ev.sender
+                    for ev in itertools.chain(events_before, [event], events_after)
                 }
-            )
 
-        rooms_cat_res = {
-            "results": results,
-            "count": count,
-            "highlights": list(highlights),
-        }
+                if events_after:
+                    last_event_id = events_after[-1].event_id
+                else:
+                    last_event_id = event.event_id
 
-        if state_results:
-            s = {}
-            for room_id, state_events in state_results.items():
-                s[room_id] = self._event_serializer.serialize_events(
-                    state_events, time_now
+                state_filter = StateFilter.from_types(
+                    [(EventTypes.Member, sender) for sender in senders]
                 )
 
-            rooms_cat_res["state"] = s
-
-        if room_groups and "room_id" in group_keys:
-            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+                state = await self.state_store.get_state_for_event(
+                    last_event_id, state_filter
+                )
 
-        if sender_group and "sender" in group_keys:
-            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+                context["profile_info"] = {
+                    s.state_key: {
+                        "displayname": s.content.get("displayname", None),
+                        "avatar_url": s.content.get("avatar_url", None),
+                    }
+                    for s in state.values()
+                    if s.type == EventTypes.Member and s.state_key in senders
+                }
 
-        if global_next_batch:
-            rooms_cat_res["next_batch"] = global_next_batch
+            contexts[event.event_id] = context
 
-        return {"search_categories": {"room_events": rooms_cat_res}}
+        return contexts
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a5764..acea300ed3 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
     ):
         super().__init__(database, db_conn, hs)
 
-    async def search_msgs(self, room_ids, search_term, keys):
+    async def search_msgs(
+        self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
-            room_ids (list): List of room ids to search in
-            search_term (str): Search term to search for
-            keys (list): List of keys to search in, currently supports
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            list of dicts
+            Dictionary of results
         """
         clauses = []
 
@@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         self,
         room_ids: Collection[str],
         search_term: str,
-        keys: List[str],
+        keys: Iterable[str],
         limit,
         pagination_token: Optional[str] = None,
-    ) -> List[dict]:
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args: