diff --git a/changelog.d/11635.feature b/changelog.d/11635.feature
new file mode 100644
index 0000000000..94c8a83212
--- /dev/null
+++ b/changelog.d/11635.feature
@@ -0,0 +1 @@
+Allow use of postgres and sqllite full-text search operators in search queries.
\ No newline at end of file
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1b79acf955..a89fc54c2c 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -11,10 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import logging
import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
+from collections import deque
+from dataclasses import dataclass
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
import attr
@@ -27,7 +39,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = _parse_query(self.database_engine, search_term)
-
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore):
count_clauses = clauses
if isinstance(self.database_engine, PostgresEngine):
+ search_query = search_term
+ tsquery_func = self.database_engine.tsquery_func
sql = (
- "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
+ f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,"
" room_id, event_id"
" FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?)"
+ f" WHERE vector @@ {tsquery_func}('english', ?)"
)
args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?)"
+ f" WHERE vector @@ {tsquery_func}('english', ?)"
)
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
+ search_query = _parse_query_for_sqlite(search_term)
+
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" FROM event_search"
@@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ?"
)
- count_args = [search_term] + count_args
+ count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = await self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(
+ search_query, events, tsquery_func
+ )
count_sql += " GROUP BY room_id"
@@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore):
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
-
return {
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
@@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore):
Each match as a dictionary.
"""
clauses = []
-
- search_query = _parse_query(self.database_engine, search_term)
-
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore):
args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine):
+ search_query = search_term
+ tsquery_func = self.database_engine.tsquery_func
sql = (
- "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
+ f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?) AND "
+ f" WHERE vector @@ {tsquery_func}('english', ?) AND "
)
args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?) AND "
+ f" WHERE vector @@ {tsquery_func}('english', ?) AND "
)
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
+
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
@@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore):
" CROSS JOIN events USING (event_id)"
" WHERE "
)
+ search_query = _parse_query_for_sqlite(search_term)
args = [search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ? AND "
)
- count_args = [search_term] + count_args
+ count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = await self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(
+ search_query, events, tsquery_func
+ )
count_sql += " GROUP BY room_id"
@@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore):
}
async def _find_highlights_in_postgres(
- self, search_query: str, events: List[EventBase]
+ self, search_query: str, events: List[EventBase], tsquery_func: str
) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore):
Args:
search_query
events: A list of events
+ tsquery_func: The tsquery_* function to use when making queries
Returns:
A set of strings.
@@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore):
while stop_sel in value:
stop_sel += ">"
- query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
+ query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % (
_to_postgres_options(
{
"StartSel": start_sel,
@@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
- """Takes a plain unicode string from the user and converts it into a form
- that can be passed to database.
- We use this so that we can add prefix matching, which isn't something
- that is supported by default.
+@dataclass
+class Phrase:
+ phrase: List[str]
+
+
+class SearchToken(enum.Enum):
+ Not = enum.auto()
+ Or = enum.auto()
+ And = enum.auto()
+
+
+Token = Union[str, Phrase, SearchToken]
+TokenList = List[Token]
+
+
+def _is_stop_word(word: str) -> bool:
+ # TODO Pull these out of the dictionary:
+ # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
+ return word in {"the", "a", "you", "me", "and", "but"}
+
+
+def _tokenize_query(query: str) -> TokenList:
+ """
+ Convert the user-supplied `query` into a TokenList, which can be translated into
+ some DB-specific syntax.
+
+ The following constructs are supported:
+
+ - phrase queries using "double quotes"
+ - case-insensitive `or` and `and` operators
+ - negation of a keyword via unary `-`
+ - unary hyphen to denote NOT e.g. 'include -exclude'
+
+ The following differs from websearch_to_tsquery:
+
+ - Stop words are not removed.
+ - Unclosed phrases are treated differently.
+
+ """
+ tokens: TokenList = []
+
+ # Find phrases.
+ in_phrase = False
+ parts = deque(query.split('"'))
+ for i, part in enumerate(parts):
+ # The contents inside double quotes is treated as a phrase, a trailing
+ # double quote is not implied.
+ in_phrase = bool(i % 2) and i != (len(parts) - 1)
+
+ # Pull out the individual words, discarding any non-word characters.
+ words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))
+
+ # Phrases have simplified handling of words.
+ if in_phrase:
+ # Skip stop words.
+ phrase = [word for word in words if not _is_stop_word(word)]
+
+ # Consecutive words are implicitly ANDed together.
+ if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+ tokens.append(SearchToken.And)
+
+ # Add the phrase.
+ tokens.append(Phrase(phrase))
+ continue
+
+ # Otherwise, not in a phrase.
+ while words:
+ word = words.popleft()
+
+ if word.startswith("-"):
+ tokens.append(SearchToken.Not)
+
+ # If there's more word, put it back to be processed again.
+ word = word[1:]
+ if word:
+ words.appendleft(word)
+ elif word.lower() == "or":
+ tokens.append(SearchToken.Or)
+ else:
+ # Skip stop words.
+ if _is_stop_word(word):
+ continue
+
+ # Consecutive words are implicitly ANDed together.
+ if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+ tokens.append(SearchToken.And)
+
+ # Add the search term.
+ tokens.append(word)
+
+ return tokens
+
+
+def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
+ """
+ Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
+ Assume sqlite was compiled with enhanced query syntax.
+
+ Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
"""
+ match_query = []
+ for token in tokens:
+ if isinstance(token, str):
+ match_query.append(token)
+ elif isinstance(token, Phrase):
+ match_query.append('"' + " ".join(token.phrase) + '"')
+ elif token == SearchToken.Not:
+ # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
+ # term has already been added before this.
+ match_query.append(" NOT ")
+ elif token == SearchToken.Or:
+ match_query.append(" OR ")
+ elif token == SearchToken.And:
+ match_query.append(" AND ")
+ else:
+ raise ValueError(f"unknown token {token}")
+
+ return "".join(match_query)
- # Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
- if isinstance(database_engine, PostgresEngine):
- return " & ".join(result + ":*" for result in results)
- elif isinstance(database_engine, Sqlite3Engine):
- return " & ".join(result + "*" for result in results)
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
+def _parse_query_for_sqlite(search_term: str) -> str:
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to sqllite's matchinfo().
+ """
+ return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index d8c0f64d9a..9bf74bbf59 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -170,6 +170,22 @@ class PostgresEngine(
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True
+ @property
+ def tsquery_func(self) -> str:
+ """
+ Selects a tsquery_* func to use.
+
+ Ref: https://www.postgresql.org/docs/current/textsearch-controls.html
+
+ Returns:
+ The function name.
+ """
+ # Postgres 11 added support for websearch_to_tsquery.
+ assert self._version is not None
+ if self._version >= 110000:
+ return "websearch_to_tsquery"
+ return "plainto_tsquery"
+
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py
new file mode 100644
index 0000000000..3de0a709eb
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py
@@ -0,0 +1,62 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None:
+ """
+ Upgrade the event_search table to use the porter tokenizer if it isn't already
+
+ Applies only for sqlite.
+ """
+ if not isinstance(database_engine, Sqlite3Engine):
+ return
+
+ # Rebuild the table event_search table with tokenize=porter configured.
+ cur.execute("DROP TABLE event_search")
+ cur.execute(
+ """
+ CREATE VIRTUAL TABLE event_search
+ USING fts4 (tokenize=porter, event_id, room_id, sender, key, value )
+ """
+ )
+
+ # Re-run the background job to re-populate the event_search table.
+ cur.execute("SELECT MIN(stream_ordering) FROM events")
+ row = cur.fetchone()
+ min_stream_id = row[0]
+
+ # If there are not any events, nothing to do.
+ if min_stream_id is None:
+ return
+
+ cur.execute("SELECT MAX(stream_ordering) FROM events")
+ row = cur.fetchone()
+ max_stream_id = row[0]
+
+ progress = {
+ "target_min_stream_id_inclusive": min_stream_id,
+ "max_stream_id_exclusive": max_stream_id + 1,
+ }
+ progress_json = json.dumps(progress)
+
+ sql = """
+ INSERT into background_updates (ordering, update_name, progress_json)
+ VALUES (?, ?, ?)
+ """
+
+ cur.execute(sql, (7310, "event_search", progress_json))
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index e747c6b50e..9ddc19900a 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple, Union
+from unittest.case import SkipTest
+from unittest.mock import PropertyMock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.databases.main import DataStore
+from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query
from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines.sqlite import Sqlite3Engine
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, skip_unless
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -187,3 +198,205 @@ class EventSearchInsertionTest(HomeserverTestCase):
),
)
self.assertCountEqual(values, ["hi", "2"])
+
+
+class MessageSearchTest(HomeserverTestCase):
+ """
+ Check message search.
+
+ A powerful way to check the behaviour is to run the following in Postgres >= 11:
+
+ # SELECT websearch_to_tsquery('english', <your string>);
+
+ The result can be compared to the tokenized version for SQLite and Postgres < 11.
+
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ PHRASE = "the quick brown fox jumps over the lazy dog"
+
+ # Each entry is a search query, followed by either a boolean of whether it is
+ # in the phrase OR a tuple of booleans: whether it matches using websearch
+ # and using plain search.
+ COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [
+ ("nope", False),
+ ("brown", True),
+ ("quick brown", True),
+ ("brown quick", True),
+ ("quick \t brown", True),
+ ("jump", True),
+ ("brown nope", False),
+ ('"brown quick"', (False, True)),
+ ('"jumps over"', True),
+ ('"quick fox"', (False, True)),
+ ("nope OR doublenope", False),
+ ("furphy OR fox", (True, False)),
+ ("fox -nope", (True, False)),
+ ("fox -brown", (False, True)),
+ ('"fox" quick', True),
+ ('"fox quick', True),
+ ('"quick brown', True),
+ ('" quick "', True),
+ ('" nope"', False),
+ ]
+ # TODO Test non-ASCII cases.
+
+ # Case that fail on SQLite.
+ POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [
+ # SQLite treats NOT as a binary operator.
+ ("- fox", (False, True)),
+ ("- nope", (True, False)),
+ ('"-fox quick', (False, True)),
+ # PostgreSQL skips stop words.
+ ('"the quick brown"', True),
+ ('"over lazy"', True),
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Register a user and create a room, create some messages
+ self.register_user("alice", "password")
+ self.access_token = self.login("alice", "password")
+ self.room_id = self.helper.create_room_as("alice", tok=self.access_token)
+
+ # Send the phrase as a message and check it was created
+ response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token)
+ self.assertIn("event_id", response)
+
+ def test_tokenize_query(self) -> None:
+ """Test the custom logic to tokenize a user's query."""
+ cases = (
+ ("brown", ["brown"]),
+ ("quick brown", ["quick", SearchToken.And, "brown"]),
+ ("quick \t brown", ["quick", SearchToken.And, "brown"]),
+ ('"brown quick"', [Phrase(["brown", "quick"])]),
+ ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]),
+ ("fox -brown", ["fox", SearchToken.Not, "brown"]),
+ ("- fox", [SearchToken.Not, "fox"]),
+ ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]),
+ # No trailing double quoe.
+ ('"fox quick', ["fox", SearchToken.And, "quick"]),
+ ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]),
+ ('" quick "', [Phrase(["quick"])]),
+ (
+ 'q"uick brow"n',
+ [
+ "q",
+ SearchToken.And,
+ Phrase(["uick", "brow"]),
+ SearchToken.And,
+ "n",
+ ],
+ ),
+ (
+ '-"quick brown"',
+ [SearchToken.Not, Phrase(["quick", "brown"])],
+ ),
+ )
+
+ for query, expected in cases:
+ tokenized = _tokenize_query(query)
+ self.assertEqual(
+ tokenized, expected, f"{tokenized} != {expected} for {query}"
+ )
+
+ def _check_test_cases(
+ self,
+ store: DataStore,
+ cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]],
+ index=0,
+ ) -> None:
+ # Run all the test cases versus search_msgs
+ for query, expect_to_contain in cases:
+ if isinstance(expect_to_contain, tuple):
+ expect_to_contain = expect_to_contain[index]
+
+ result = self.get_success(
+ store.search_msgs([self.room_id], query, ["content.body"])
+ )
+ self.assertEquals(
+ result["count"],
+ 1 if expect_to_contain else 0,
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+ )
+ self.assertEquals(
+ len(result["results"]),
+ 1 if expect_to_contain else 0,
+ "results array length should match count",
+ )
+
+ # Run them again versus search_rooms
+ for query, expect_to_contain in cases:
+ if isinstance(expect_to_contain, tuple):
+ expect_to_contain = expect_to_contain[index]
+
+ result = self.get_success(
+ store.search_rooms([self.room_id], query, ["content.body"], 10)
+ )
+ self.assertEquals(
+ result["count"],
+ 1 if expect_to_contain else 0,
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+ )
+ self.assertEquals(
+ len(result["results"]),
+ 1 if expect_to_contain else 0,
+ "results array length should match count",
+ )
+
+ def test_postgres_web_search_for_phrase(self):
+ """
+ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
+ This test is skipped unless the postgres instance supports websearch_to_tsquery.
+ """
+
+ store = self.hs.get_datastores().main
+ if not isinstance(store.database_engine, PostgresEngine):
+ raise SkipTest("Test only applies when postgres is used as the database")
+
+ if store.database_engine.tsquery_func != "websearch_to_tsquery":
+ raise SkipTest(
+ "Test only applies when postgres supporting websearch_to_tsquery is used as the database"
+ )
+
+ self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0)
+
+ def test_postgres_non_web_search_for_phrase(self):
+ """
+ Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't
+ supported by the current postgres version.
+ """
+
+ store = self.hs.get_datastores().main
+ if not isinstance(store.database_engine, PostgresEngine):
+ raise SkipTest("Test only applies when postgres is used as the database")
+
+ # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path.
+ with patch(
+ "synapse.storage.engines.postgres.PostgresEngine.tsquery_func",
+ new_callable=PropertyMock,
+ ) as supports_websearch_to_tsquery:
+ supports_websearch_to_tsquery.return_value = "plainto_tsquery"
+ self._check_test_cases(
+ store, self.COMMON_CASES + self.POSTGRES_CASES, index=1
+ )
+
+ def test_sqlite_search(self):
+ """
+ Test sqlite searching for phrases.
+ """
+ store = self.hs.get_datastores().main
+ if not isinstance(store.database_engine, Sqlite3Engine):
+ raise SkipTest("Test only applies when sqlite is used as the database")
+
+ self._check_test_cases(store, self.COMMON_CASES, index=0)
|