summary refs log tree commit diff
path: root/synapse/storage/databases/main/search.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/search.py')
-rw-r--r--synapse/storage/databases/main/search.py374
1 files changed, 253 insertions, 121 deletions
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1b79acf955..3fe433f66c 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -11,10 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
+from collections import deque
+from dataclasses import dataclass
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import attr
 
@@ -27,7 +39,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -68,11 +80,11 @@ class SearchWorkerStore(SQLBaseStore):
         if not self.hs.config.server.enable_search:
             return
         if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
+            sql = """
+            INSERT INTO event_search
+            (event_id, room_id, key, vector, stream_ordering, origin_server_ts)
+            VALUES (?,?,?,to_tsvector('english', ?),?,?)
+            """
 
             args1 = (
                 (
@@ -89,20 +101,20 @@ class SearchWorkerStore(SQLBaseStore):
             txn.execute_batch(sql, args1)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args2 = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    _clean_value_for_search(entry.value),
-                )
-                for entry in entries
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="event_search",
+                keys=("event_id", "room_id", "key", "value"),
+                values=(
+                    (
+                        entry.event_id,
+                        entry.room_id,
+                        entry.key,
+                        _clean_value_for_search(entry.value),
+                    )
+                    for entry in entries
+                ),
             )
-            txn.execute_batch(sql, args2)
 
         else:
             # This should be unreachable.
@@ -150,15 +162,17 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
         def reindex_search_txn(txn: LoggingTransaction) -> int:
-            sql = (
-                "SELECT stream_ordering, event_id, room_id, type, json, "
-                " origin_server_ts FROM events"
-                " JOIN event_json USING (room_id, event_id)"
-                " WHERE ? <= stream_ordering AND stream_ordering < ?"
-                " AND (%s)"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
-            ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
+            sql = """
+            SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
+            FROM events
+            JOIN event_json USING (room_id, event_id)
+            WHERE ? <= stream_ordering AND stream_ordering < ?
+            AND (%s)
+            ORDER BY stream_ordering DESC
+            LIMIT ?
+            """ % (
+                " OR ".join("type = '%s'" % (t,) for t in TYPES),
+            )
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
@@ -272,8 +286,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
                 try:
                     c.execute(
-                        "CREATE INDEX CONCURRENTLY event_search_fts_idx"
-                        " ON event_search USING GIN (vector)"
+                        """
+                        CREATE INDEX CONCURRENTLY event_search_fts_idx
+                        ON event_search USING GIN (vector)
+                        """
                     )
                 except psycopg2.ProgrammingError as e:
                     logger.warning(
@@ -311,12 +327,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 # We create with NULLS FIRST so that when we search *backwards*
                 # we get the ones with non null origin_server_ts *first*
                 c.execute(
-                    "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
-                    "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+                    """
+                    CREATE INDEX CONCURRENTLY event_search_room_order
+                    ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
+                    """
                 )
                 c.execute(
-                    "CREATE INDEX CONCURRENTLY event_search_order ON event_search("
-                    "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+                    """
+                    CREATE INDEX CONCURRENTLY event_search_order
+                    ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
+                    """
                 )
                 conn.set_session(autocommit=False)
 
@@ -333,14 +353,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             )
 
         def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
-            sql = (
-                "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
-                " origin_server_ts = e.origin_server_ts"
-                " FROM events AS e"
-                " WHERE e.event_id = es.event_id"
-                " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
-                " RETURNING es.stream_ordering"
-            )
+            sql = """
+            UPDATE event_search AS es
+            SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts
+            FROM events AS e
+            WHERE e.event_id = es.event_id
+            AND ? <= e.stream_ordering AND e.stream_ordering < ?
+            RETURNING es.stream_ordering
+            """
 
             min_stream_id = max_stream_id - batch_size
             txn.execute(sql, (min_stream_id, max_stream_id))
@@ -421,8 +441,6 @@ class SearchStore(SearchBackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = _parse_query(self.database_engine, search_term)
-
         args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
@@ -444,32 +462,35 @@ class SearchStore(SearchBackgroundUpdateStore):
         count_clauses = clauses
 
         if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
-                " room_id, event_id"
-                " FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?)"
-            )
+            search_query = search_term
+            sql = """
+            SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank,
+            room_id, event_id
+            FROM event_search
+            WHERE vector @@  websearch_to_tsquery('english', ?)
+            """
             args = [search_query, search_query] + args
 
-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?)"
-            )
+            count_sql = """
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE vector @@ websearch_to_tsquery('english', ?)
+            """
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
-                " FROM event_search"
-                " WHERE value MATCH ?"
-            )
+            search_query = _parse_query_for_sqlite(search_term)
+
+            sql = """
+            SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
+            FROM event_search
+            WHERE value MATCH ?
+            """
             args = [search_query] + args
 
-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE value MATCH ?"
-            )
-            count_args = [search_term] + count_args
+            count_sql = """
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE value MATCH ?
+            """
+            count_args = [search_query] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -510,7 +531,6 @@ class SearchStore(SearchBackgroundUpdateStore):
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
-
         return {
             "results": [
                 {"event": event_map[r["event_id"]], "rank": r["rank"]}
@@ -542,9 +562,6 @@ class SearchStore(SearchBackgroundUpdateStore):
             Each match as a dictionary.
         """
         clauses = []
-
-        search_query = _parse_query(self.database_engine, search_term)
-
         args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
@@ -576,26 +593,29 @@ class SearchStore(SearchBackgroundUpdateStore):
                 raise SynapseError(400, "Invalid pagination token")
 
             clauses.append(
-                "(origin_server_ts < ?"
-                " OR (origin_server_ts = ? AND stream_ordering < ?))"
+                """
+                (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?))
+                """
             )
             args.extend([origin_server_ts, origin_server_ts, stream])
 
         if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
-                " origin_server_ts, stream_ordering, room_id, event_id"
-                " FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?) AND "
-            )
+            search_query = search_term
+            sql = """
+            SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
+            origin_server_ts, stream_ordering, room_id, event_id
+            FROM event_search
+            WHERE vector @@ websearch_to_tsquery('english', ?) AND
+            """
             args = [search_query, search_query] + args
 
-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE vector @@ to_tsquery('english', ?) AND "
-            )
+            count_sql = """
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE vector @@ websearch_to_tsquery('english', ?) AND
+            """
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
+
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
             #
@@ -604,23 +624,25 @@ class SearchStore(SearchBackgroundUpdateStore):
             # in the events table to get the topological ordering. We need
             # to use the indexes in this order because sqlite refuses to
             # MATCH unless it uses the full text search index
-            sql = (
-                "SELECT rank(matchinfo) as rank, room_id, event_id,"
-                " origin_server_ts, stream_ordering"
-                " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
-                " FROM event_search"
-                " WHERE value MATCH ?"
-                " )"
-                " CROSS JOIN events USING (event_id)"
-                " WHERE "
+            sql = """
+            SELECT
+                rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering
+            FROM (
+                SELECT key, event_id, matchinfo(event_search) as matchinfo
+                FROM event_search
+                WHERE value MATCH ?
             )
+            CROSS JOIN events USING (event_id)
+            WHERE
+            """
+            search_query = _parse_query_for_sqlite(search_term)
             args = [search_query] + args
 
-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE value MATCH ? AND "
-            )
-            count_args = [search_term] + count_args
+            count_sql = """
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE value MATCH ? AND
+            """
+            count_args = [search_query] + count_args
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -631,10 +653,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
         if isinstance(self.database_engine, PostgresEngine):
-            sql += (
-                " ORDER BY origin_server_ts DESC NULLS LAST,"
-                " stream_ordering DESC NULLS LAST LIMIT ?"
-            )
+            sql += """
+            ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST
+            LIMIT ?
+            """
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
         else:
@@ -729,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore):
                 while stop_sel in value:
                     stop_sel += ">"
 
-                query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
-                    _to_postgres_options(
-                        {
-                            "StartSel": start_sel,
-                            "StopSel": stop_sel,
-                            "MaxFragments": "50",
-                        }
+                query = (
+                    "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)"
+                    % (
+                        _to_postgres_options(
+                            {
+                                "StartSel": start_sel,
+                                "StopSel": stop_sel,
+                                "MaxFragments": "50",
+                            }
+                        )
                     )
                 )
                 txn.execute(query, (value, search_query))
@@ -760,20 +785,127 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
-    """Takes a plain unicode string from the user and converts it into a form
-    that can be passed to database.
-    We use this so that we can add prefix matching, which isn't something
-    that is supported by default.
+@dataclass
+class Phrase:
+    phrase: List[str]
+
+
+class SearchToken(enum.Enum):
+    Not = enum.auto()
+    Or = enum.auto()
+    And = enum.auto()
+
+
+Token = Union[str, Phrase, SearchToken]
+TokenList = List[Token]
+
+
+def _is_stop_word(word: str) -> bool:
+    # TODO Pull these out of the dictionary:
+    #  https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
+    return word in {"the", "a", "you", "me", "and", "but"}
+
+
+def _tokenize_query(query: str) -> TokenList:
     """
+    Convert the user-supplied `query` into a TokenList, which can be translated into
+    some DB-specific syntax.
+
+    The following constructs are supported:
+
+    - phrase queries using "double quotes"
+    - case-insensitive `or` and `and` operators
+    - negation of a keyword via unary `-`
+    - unary hyphen to denote NOT e.g. 'include -exclude'
 
-    # Pull out the individual words, discarding any non-word characters.
-    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+    The following differs from websearch_to_tsquery:
 
-    if isinstance(database_engine, PostgresEngine):
-        return " & ".join(result + ":*" for result in results)
-    elif isinstance(database_engine, Sqlite3Engine):
-        return " & ".join(result + "*" for result in results)
-    else:
-        # This should be unreachable.
-        raise Exception("Unrecognized database engine")
+    - Stop words are not removed.
+    - Unclosed phrases are treated differently.
+
+    """
+    tokens: TokenList = []
+
+    # Find phrases.
+    in_phrase = False
+    parts = deque(query.split('"'))
+    for i, part in enumerate(parts):
+        # The contents inside double quotes is treated as a phrase.
+        in_phrase = bool(i % 2)
+
+        # Pull out the individual words, discarding any non-word characters.
+        words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))
+
+        # Phrases have simplified handling of words.
+        if in_phrase:
+            # Skip stop words.
+            phrase = [word for word in words if not _is_stop_word(word)]
+
+            # Consecutive words are implicitly ANDed together.
+            if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+                tokens.append(SearchToken.And)
+
+            # Add the phrase.
+            tokens.append(Phrase(phrase))
+            continue
+
+        # Otherwise, not in a phrase.
+        while words:
+            word = words.popleft()
+
+            if word.startswith("-"):
+                tokens.append(SearchToken.Not)
+
+                # If there's more word, put it back to be processed again.
+                word = word[1:]
+                if word:
+                    words.appendleft(word)
+            elif word.lower() == "or":
+                tokens.append(SearchToken.Or)
+            else:
+                # Skip stop words.
+                if _is_stop_word(word):
+                    continue
+
+                # Consecutive words are implicitly ANDed together.
+                if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+                    tokens.append(SearchToken.And)
+
+                # Add the search term.
+                tokens.append(word)
+
+    return tokens
+
+
+def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
+    """
+    Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
+    Assume sqlite was compiled with enhanced query syntax.
+
+    Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
+    """
+    match_query = []
+    for token in tokens:
+        if isinstance(token, str):
+            match_query.append(token)
+        elif isinstance(token, Phrase):
+            match_query.append('"' + " ".join(token.phrase) + '"')
+        elif token == SearchToken.Not:
+            # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
+            # term has already been added before this.
+            match_query.append(" NOT ")
+        elif token == SearchToken.Or:
+            match_query.append(" OR ")
+        elif token == SearchToken.And:
+            match_query.append(" AND ")
+        else:
+            raise ValueError(f"unknown token {token}")
+
+    return "".join(match_query)
+
+
+def _parse_query_for_sqlite(search_term: str) -> str:
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to sqllite's matchinfo().
+    """
+    return _tokens_to_sqlite_match_query(_tokenize_query(search_term))