diff options
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r-- | synapse/storage/databases/main/search.py | 197 |
1 files changed, 35 insertions, 162 deletions
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index a89fc54c2c..1b79acf955 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -11,22 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum + import logging import re -from collections import deque -from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple import attr @@ -39,7 +27,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -433,6 +421,8 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] + search_query = _parse_query(self.database_engine, search_term) + args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -454,24 +444,20 @@ class SearchStore(SearchBackgroundUpdateStore): count_clauses = clauses if isinstance(self.database_engine, PostgresEngine): - search_query = search_term - tsquery_func = self.database_engine.tsquery_func sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," " room_id, event_id" " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" + " WHERE vector @@ to_tsquery('english', ?)" ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?)" + " WHERE vector @@ to_tsquery('english', ?)" ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - search_query = _parse_query_for_sqlite(search_term) - sql = ( "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" " FROM event_search" @@ -483,7 +469,7 @@ class SearchStore(SearchBackgroundUpdateStore): "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ?" ) - count_args = [search_query] + count_args + count_args = [search_term] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -515,9 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -526,6 +510,7 @@ class SearchStore(SearchBackgroundUpdateStore): ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) + return { "results": [ {"event": event_map[r["event_id"]], "rank": r["rank"]} @@ -557,6 +542,9 @@ class SearchStore(SearchBackgroundUpdateStore): Each match as a dictionary. """ clauses = [] + + search_query = _parse_query(self.database_engine, search_term) + args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -594,23 +582,20 @@ class SearchStore(SearchBackgroundUpdateStore): args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): - search_query = search_term - tsquery_func = self.database_engine.tsquery_func sql = ( - f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," + "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," " origin_server_ts, stream_ordering, room_id, event_id" " FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " + " WHERE vector @@ to_tsquery('english', ?) AND " ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - f" WHERE vector @@ {tsquery_func}('english', ?) AND " + " WHERE vector @@ to_tsquery('english', ?) AND " ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # @@ -629,14 +614,13 @@ class SearchStore(SearchBackgroundUpdateStore): " CROSS JOIN events USING (event_id)" " WHERE " ) - search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ? AND " ) - count_args = [search_query] + count_args + count_args = [search_term] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -676,9 +660,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -704,7 +686,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase], tsquery_func: str + self, search_query: str, events: List[EventBase] ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -715,7 +697,6 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events - tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -748,7 +729,7 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( + query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( _to_postgres_options( { "StartSel": start_sel, @@ -779,128 +760,20 @@ def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -@dataclass -class Phrase: - phrase: List[str] - - -class SearchToken(enum.Enum): - Not = enum.auto() - Or = enum.auto() - And = enum.auto() - - -Token = Union[str, Phrase, SearchToken] -TokenList = List[Token] - - -def _is_stop_word(word: str) -> bool: - # TODO Pull these out of the dictionary: - # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop - return word in {"the", "a", "you", "me", "and", "but"} - - -def _tokenize_query(query: str) -> TokenList: - """ - Convert the user-supplied `query` into a TokenList, which can be translated into - some DB-specific syntax. - - The following constructs are supported: - - - phrase queries using "double quotes" - - case-insensitive `or` and `and` operators - - negation of a keyword via unary `-` - - unary hyphen to denote NOT e.g. 'include -exclude' - - The following differs from websearch_to_tsquery: - - - Stop words are not removed. - - Unclosed phrases are treated differently. - - """ - tokens: TokenList = [] - - # Find phrases. - in_phrase = False - parts = deque(query.split('"')) - for i, part in enumerate(parts): - # The contents inside double quotes is treated as a phrase, a trailing - # double quote is not implied. - in_phrase = bool(i % 2) and i != (len(parts) - 1) - - # Pull out the individual words, discarding any non-word characters. - words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) - - # Phrases have simplified handling of words. - if in_phrase: - # Skip stop words. - phrase = [word for word in words if not _is_stop_word(word)] - - # Consecutive words are implicitly ANDed together. - if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): - tokens.append(SearchToken.And) - - # Add the phrase. - tokens.append(Phrase(phrase)) - continue - - # Otherwise, not in a phrase. - while words: - word = words.popleft() - - if word.startswith("-"): - tokens.append(SearchToken.Not) - - # If there's more word, put it back to be processed again. - word = word[1:] - if word: - words.appendleft(word) - elif word.lower() == "or": - tokens.append(SearchToken.Or) - else: - # Skip stop words. - if _is_stop_word(word): - continue - - # Consecutive words are implicitly ANDed together. - if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): - tokens.append(SearchToken.And) - - # Add the search term. - tokens.append(word) - - return tokens - - -def _tokens_to_sqlite_match_query(tokens: TokenList) -> str: - """ - Convert the list of tokens to a string suitable for passing to sqlite's MATCH. - Assume sqlite was compiled with enhanced query syntax. - - Ref: https://www.sqlite.org/fts3.html#full_text_index_queries +def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. """ - match_query = [] - for token in tokens: - if isinstance(token, str): - match_query.append(token) - elif isinstance(token, Phrase): - match_query.append('"' + " ".join(token.phrase) + '"') - elif token == SearchToken.Not: - # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search - # term has already been added before this. - match_query.append(" NOT ") - elif token == SearchToken.Or: - match_query.append(" OR ") - elif token == SearchToken.And: - match_query.append(" AND ") - else: - raise ValueError(f"unknown token {token}") - - return "".join(match_query) + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) -def _parse_query_for_sqlite(search_term: str) -> str: - """Takes a plain unicode string from the user and converts it into a form - that can be passed to sqllite's matchinfo(). - """ - return _tokens_to_sqlite_match_query(_tokenize_query(search_term)) + if isinstance(database_engine, PostgresEngine): + return " & ".join(result + ":*" for result in results) + elif isinstance(database_engine, Sqlite3Engine): + return " & ".join(result + "*" for result in results) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") |