summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/search.py197
-rw-r--r--synapse/storage/engines/postgres.py16
-rw-r--r--synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py62
3 files changed, 240 insertions, 35 deletions
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1b79acf955..a89fc54c2c 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -11,10 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
+from collections import deque
+from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import attr
 
@@ -27,7 +39,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = _parse_query(self.database_engine, search_term)
-
         args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
@@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore):
         count_clauses = clauses
 
         if isinstance(self.database_engine, PostgresEngine):
+            search_query = search_term
+            tsquery_func = self.database_engine.tsquery_func
             sql = (
-                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
+                f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,"
                 " room_id, event_id"
                 " FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?)"
+                f" WHERE vector @@  {tsquery_func}('english', ?)"
             )
             args = [search_query, search_query] + args
 
             count_sql = (
                 "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?)"
+                f" WHERE vector @@ {tsquery_func}('english', ?)"
             )
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
+            search_query = _parse_query_for_sqlite(search_term)
+
             sql = (
                 "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
                 " FROM event_search"
@@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore):
                 "SELECT room_id, count(*) as count FROM event_search"
                 " WHERE value MATCH ?"
             )
-            count_args = [search_term] + count_args
+            count_args = [search_query] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         highlights = None
         if isinstance(self.database_engine, PostgresEngine):
-            highlights = await self._find_highlights_in_postgres(search_query, events)
+            highlights = await self._find_highlights_in_postgres(
+                search_query, events, tsquery_func
+            )
 
         count_sql += " GROUP BY room_id"
 
@@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore):
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
-
         return {
             "results": [
                 {"event": event_map[r["event_id"]], "rank": r["rank"]}
@@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore):
             Each match as a dictionary.
         """
         clauses = []
-
-        search_query = _parse_query(self.database_engine, search_term)
-
         args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
@@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore):
             args.extend([origin_server_ts, origin_server_ts, stream])
 
         if isinstance(self.database_engine, PostgresEngine):
+            search_query = search_term
+            tsquery_func = self.database_engine.tsquery_func
             sql = (
-                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
+                f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,"
                 " origin_server_ts, stream_ordering, room_id, event_id"
                 " FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?) AND "
+                f" WHERE vector @@ {tsquery_func}('english', ?) AND "
             )
             args = [search_query, search_query] + args
 
             count_sql = (
                 "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?) AND "
+                f" WHERE vector @@ {tsquery_func}('english', ?) AND "
             )
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
+
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
             #
@@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore):
                 " CROSS JOIN events USING (event_id)"
                 " WHERE "
             )
+            search_query = _parse_query_for_sqlite(search_term)
             args = [search_query] + args
 
             count_sql = (
                 "SELECT room_id, count(*) as count FROM event_search"
                 " WHERE value MATCH ? AND "
             )
-            count_args = [search_term] + count_args
+            count_args = [search_query] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         highlights = None
         if isinstance(self.database_engine, PostgresEngine):
-            highlights = await self._find_highlights_in_postgres(search_query, events)
+            highlights = await self._find_highlights_in_postgres(
+                search_query, events, tsquery_func
+            )
 
         count_sql += " GROUP BY room_id"
 
@@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         }
 
     async def _find_highlights_in_postgres(
-        self, search_query: str, events: List[EventBase]
+        self, search_query: str, events: List[EventBase], tsquery_func: str
     ) -> Set[str]:
         """Given a list of events and a search term, return a list of words
         that match from the content of the event.
@@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         Args:
             search_query
             events: A list of events
+            tsquery_func: The tsquery_* function to use when making queries
 
         Returns:
             A set of strings.
@@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore):
                 while stop_sel in value:
                     stop_sel += ">"
 
-                query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
+                query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % (
                     _to_postgres_options(
                         {
                             "StartSel": start_sel,
@@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
-    """Takes a plain unicode string from the user and converts it into a form
-    that can be passed to database.
-    We use this so that we can add prefix matching, which isn't something
-    that is supported by default.
+@dataclass
+class Phrase:
+    phrase: List[str]
+
+
+class SearchToken(enum.Enum):
+    Not = enum.auto()
+    Or = enum.auto()
+    And = enum.auto()
+
+
+Token = Union[str, Phrase, SearchToken]
+TokenList = List[Token]
+
+
+def _is_stop_word(word: str) -> bool:
+    # TODO Pull these out of the dictionary:
+    #  https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
+    return word in {"the", "a", "you", "me", "and", "but"}
+
+
+def _tokenize_query(query: str) -> TokenList:
+    """
+    Convert the user-supplied `query` into a TokenList, which can be translated into
+    some DB-specific syntax.
+
+    The following constructs are supported:
+
+    - phrase queries using "double quotes"
+    - case-insensitive `or` and `and` operators
+    - negation of a keyword via unary `-`
+    - unary hyphen to denote NOT e.g. 'include -exclude'
+
+    The following differs from websearch_to_tsquery:
+
+    - Stop words are not removed.
+    - Unclosed phrases are treated differently.
+
+    """
+    tokens: TokenList = []
+
+    # Find phrases.
+    in_phrase = False
+    parts = deque(query.split('"'))
+    for i, part in enumerate(parts):
+        # The contents inside double quotes is treated as a phrase, a trailing
+        # double quote is not implied.
+        in_phrase = bool(i % 2) and i != (len(parts) - 1)
+
+        # Pull out the individual words, discarding any non-word characters.
+        words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))
+
+        # Phrases have simplified handling of words.
+        if in_phrase:
+            # Skip stop words.
+            phrase = [word for word in words if not _is_stop_word(word)]
+
+            # Consecutive words are implicitly ANDed together.
+            if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+                tokens.append(SearchToken.And)
+
+            # Add the phrase.
+            tokens.append(Phrase(phrase))
+            continue
+
+        # Otherwise, not in a phrase.
+        while words:
+            word = words.popleft()
+
+            if word.startswith("-"):
+                tokens.append(SearchToken.Not)
+
+                # If there's more word, put it back to be processed again.
+                word = word[1:]
+                if word:
+                    words.appendleft(word)
+            elif word.lower() == "or":
+                tokens.append(SearchToken.Or)
+            else:
+                # Skip stop words.
+                if _is_stop_word(word):
+                    continue
+
+                # Consecutive words are implicitly ANDed together.
+                if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+                    tokens.append(SearchToken.And)
+
+                # Add the search term.
+                tokens.append(word)
+
+    return tokens
+
+
+def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
+    """
+    Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
+    Assume sqlite was compiled with enhanced query syntax.
+
+    Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
     """
+    match_query = []
+    for token in tokens:
+        if isinstance(token, str):
+            match_query.append(token)
+        elif isinstance(token, Phrase):
+            match_query.append('"' + " ".join(token.phrase) + '"')
+        elif token == SearchToken.Not:
+            # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
+            # term has already been added before this.
+            match_query.append(" NOT ")
+        elif token == SearchToken.Or:
+            match_query.append(" OR ")
+        elif token == SearchToken.And:
+            match_query.append(" AND ")
+        else:
+            raise ValueError(f"unknown token {token}")
+
+    return "".join(match_query)
 
-    # Pull out the individual words, discarding any non-word characters.
-    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
 
-    if isinstance(database_engine, PostgresEngine):
-        return " & ".join(result + ":*" for result in results)
-    elif isinstance(database_engine, Sqlite3Engine):
-        return " & ".join(result + "*" for result in results)
-    else:
-        # This should be unreachable.
-        raise Exception("Unrecognized database engine")
+def _parse_query_for_sqlite(search_term: str) -> str:
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to sqllite's matchinfo().
+    """
+    return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index d8c0f64d9a..9bf74bbf59 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -170,6 +170,22 @@ class PostgresEngine(
         """Do we support the `RETURNING` clause in insert/update/delete?"""
         return True
 
+    @property
+    def tsquery_func(self) -> str:
+        """
+        Selects a tsquery_* func to use.
+
+        Ref: https://www.postgresql.org/docs/current/textsearch-controls.html
+
+        Returns:
+            The function name.
+        """
+        # Postgres 11 added support for websearch_to_tsquery.
+        assert self._version is not None
+        if self._version >= 110000:
+            return "websearch_to_tsquery"
+        return "plainto_tsquery"
+
     def is_deadlock(self, error: Exception) -> bool:
         if isinstance(error, psycopg2.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py
new file mode 100644
index 0000000000..3de0a709eb
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py
@@ -0,0 +1,62 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None:
+    """
+    Upgrade the event_search table to use the porter tokenizer if it isn't already
+
+    Applies only for sqlite.
+    """
+    if not isinstance(database_engine, Sqlite3Engine):
+        return
+
+    # Rebuild the table event_search table with tokenize=porter configured.
+    cur.execute("DROP TABLE event_search")
+    cur.execute(
+        """
+        CREATE VIRTUAL TABLE event_search
+        USING fts4 (tokenize=porter, event_id, room_id, sender, key, value )
+        """
+    )
+
+    # Re-run the background job to re-populate the event_search table.
+    cur.execute("SELECT MIN(stream_ordering) FROM events")
+    row = cur.fetchone()
+    min_stream_id = row[0]
+
+    # If there are not any events, nothing to do.
+    if min_stream_id is None:
+        return
+
+    cur.execute("SELECT MAX(stream_ordering) FROM events")
+    row = cur.fetchone()
+    max_stream_id = row[0]
+
+    progress = {
+        "target_min_stream_id_inclusive": min_stream_id,
+        "max_stream_id_exclusive": max_stream_id + 1,
+    }
+    progress_json = json.dumps(progress)
+
+    sql = """
+    INSERT into background_updates (ordering, update_name, progress_json)
+    VALUES (?, ?, ?)
+    """
+
+    cur.execute(sql, (7310, "event_search", progress_json))