diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index f6e24b68d2..3fe433f66c 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -11,10 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import logging
import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
+from collections import deque
+from dataclasses import dataclass
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
import attr
@@ -27,7 +39,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -68,11 +80,11 @@ class SearchWorkerStore(SQLBaseStore):
if not self.hs.config.server.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
+ sql = """
+ INSERT INTO event_search
+ (event_id, room_id, key, vector, stream_ordering, origin_server_ts)
+ VALUES (?,?,?,to_tsvector('english', ?),?,?)
+ """
args1 = (
(
@@ -89,20 +101,20 @@ class SearchWorkerStore(SQLBaseStore):
txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- args2 = (
- (
- entry.event_id,
- entry.room_id,
- entry.key,
- _clean_value_for_search(entry.value),
- )
- for entry in entries
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="event_search",
+ keys=("event_id", "room_id", "key", "value"),
+ values=(
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ _clean_value_for_search(entry.value),
+ )
+ for entry in entries
+ ),
)
- txn.execute_batch(sql, args2)
else:
# This should be unreachable.
@@ -150,15 +162,17 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn: LoggingTransaction) -> int:
- sql = (
- "SELECT stream_ordering, event_id, room_id, type, json, "
- " origin_server_ts FROM events"
- " JOIN event_json USING (room_id, event_id)"
- " WHERE ? <= stream_ordering AND stream_ordering < ?"
- " AND (%s)"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
+ sql = """
+ SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
+ FROM events
+ JOIN event_json USING (room_id, event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND (%s)
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """ % (
+ " OR ".join("type = '%s'" % (t,) for t in TYPES),
+ )
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
@@ -272,8 +286,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
try:
c.execute(
- "CREATE INDEX CONCURRENTLY event_search_fts_idx"
- " ON event_search USING GIN (vector)"
+ """
+ CREATE INDEX CONCURRENTLY event_search_fts_idx
+ ON event_search USING GIN (vector)
+ """
)
except psycopg2.ProgrammingError as e:
logger.warning(
@@ -311,12 +327,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# We create with NULLS FIRST so that when we search *backwards*
# we get the ones with non null origin_server_ts *first*
c.execute(
- "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
- "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+ """
+ CREATE INDEX CONCURRENTLY event_search_room_order
+ ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
+ """
)
c.execute(
- "CREATE INDEX CONCURRENTLY event_search_order ON event_search("
- "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+ """
+ CREATE INDEX CONCURRENTLY event_search_order
+ ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
+ """
)
conn.set_session(autocommit=False)
@@ -333,14 +353,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
- sql = (
- "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
- " origin_server_ts = e.origin_server_ts"
- " FROM events AS e"
- " WHERE e.event_id = es.event_id"
- " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
- " RETURNING es.stream_ordering"
- )
+ sql = """
+ UPDATE event_search AS es
+ SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts
+ FROM events AS e
+ WHERE e.event_id = es.event_id
+ AND ? <= e.stream_ordering AND e.stream_ordering < ?
+ RETURNING es.stream_ordering
+ """
min_stream_id = max_stream_id - batch_size
txn.execute(sql, (min_stream_id, max_stream_id))
@@ -421,8 +441,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = _parse_query(self.database_engine, search_term)
-
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -444,32 +462,35 @@ class SearchStore(SearchBackgroundUpdateStore):
count_clauses = clauses
if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
- " room_id, event_id"
- " FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?)"
- )
+ search_query = search_term
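+ # Postgres parses the query string itself: websearch_to_tsquery
+ # understands quoted phrases, unary '-' for negation and an OR
+ # operator, so the raw user-supplied term is passed straight through.
+ # Illustrative example (behaviour per the Postgres docs, not
+ # exercised in this patch):
+ #   websearch_to_tsquery('english', 'fat -rat "sad cat"')
+ #   is roughly 'fat' & !'rat' & ('sad' <-> 'cat')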
+ sql = """
+ SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank,
+ room_id, event_id
+ FROM event_search
+ WHERE vector @@ websearch_to_tsquery('english', ?)
+ """
args = [search_query, search_query] + args
- count_sql = (
- "SELECT room_id, count(*) as count FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?)"
- )
+ count_sql = """
+ SELECT room_id, count(*) as count FROM event_search
+ WHERE vector @@ websearch_to_tsquery('english', ?)
+ """
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
- " FROM event_search"
- " WHERE value MATCH ?"
- )
+ search_query = _parse_query_for_sqlite(search_term)
+
+ sql = """
+ SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
+ FROM event_search
+ WHERE value MATCH ?
+ """
args = [search_query] + args
- count_sql = (
- "SELECT room_id, count(*) as count FROM event_search"
- " WHERE value MATCH ?"
- )
- count_args = [search_term] + count_args
+ count_sql = """
+ SELECT room_id, count(*) as count FROM event_search
+ WHERE value MATCH ?
+ """
+ count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -510,7 +531,6 @@ class SearchStore(SearchBackgroundUpdateStore):
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
-
return {
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
@@ -542,9 +562,6 @@ class SearchStore(SearchBackgroundUpdateStore):
Each match as a dictionary.
"""
clauses = []
-
- search_query = _parse_query(self.database_engine, search_term)
-
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -576,26 +593,29 @@ class SearchStore(SearchBackgroundUpdateStore):
raise SynapseError(400, "Invalid pagination token")
clauses.append(
- "(origin_server_ts < ?"
- " OR (origin_server_ts = ? AND stream_ordering < ?))"
+ """
+ (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?))
+ """
)
args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
- " origin_server_ts, stream_ordering, room_id, event_id"
- " FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?) AND "
- )
+ search_query = search_term
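+ # As in search_msgs above, the raw term is handed to websearch_to_tsquery.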
+ sql = """
+ SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
+ origin_server_ts, stream_ordering, room_id, event_id
+ FROM event_search
+ WHERE vector @@ websearch_to_tsquery('english', ?) AND
+ """
args = [search_query, search_query] + args
- count_sql = (
- "SELECT room_id, count(*) as count FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?) AND "
- )
+ count_sql = """
+ SELECT room_id, count(*) as count FROM event_search
+ WHERE vector @@ websearch_to_tsquery('english', ?) AND
+ """
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
@@ -604,23 +624,25 @@ class SearchStore(SearchBackgroundUpdateStore):
# in the events table to get the topological ordering. We need
# to use the indexes in this order because sqlite refuses to
# MATCH unless it uses the full text search index
- sql = (
- "SELECT rank(matchinfo) as rank, room_id, event_id,"
- " origin_server_ts, stream_ordering"
- " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
- " FROM event_search"
- " WHERE value MATCH ?"
- " )"
- " CROSS JOIN events USING (event_id)"
- " WHERE "
+ sql = """
+ SELECT
+ rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering
+ FROM (
+ SELECT key, event_id, matchinfo(event_search) as matchinfo
+ FROM event_search
+ WHERE value MATCH ?
)
+ CROSS JOIN events USING (event_id)
+ WHERE
+ """
+ search_query = _parse_query_for_sqlite(search_term)
args = [search_query] + args
- count_sql = (
- "SELECT room_id, count(*) as count FROM event_search"
- " WHERE value MATCH ? AND "
- )
- count_args = [search_term] + count_args
+ count_sql = """
+ SELECT room_id, count(*) as count FROM event_search
+ WHERE value MATCH ? AND
+ """
+ count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -631,17 +653,17 @@ class SearchStore(SearchBackgroundUpdateStore):
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
if isinstance(self.database_engine, PostgresEngine):
- sql += (
- " ORDER BY origin_server_ts DESC NULLS LAST,"
- " stream_ordering DESC NULLS LAST LIMIT ?"
- )
+ sql += """
+ ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST
+ LIMIT ?
+ """
elif isinstance(self.database_engine, Sqlite3Engine):
sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
else:
raise Exception("Unrecognized database engine")
- # mypy expects to append only a `str`, not an `int`
- args.append(limit) # type: ignore[arg-type]
+ args.append(limit)
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -729,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore):
while stop_sel in value:
stop_sel += ">"
- query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
- _to_postgres_options(
- {
- "StartSel": start_sel,
- "StopSel": stop_sel,
- "MaxFragments": "50",
- }
+ query = (
+ "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)"
+ % (
+ _to_postgres_options(
+ {
+ "StartSel": start_sel,
+ "StopSel": stop_sel,
+ "MaxFragments": "50",
+ }
+ )
)
)
txn.execute(query, (value, search_query))
@@ -760,20 +785,127 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
- """Takes a plain unicode string from the user and converts it into a form
- that can be passed to database.
- We use this so that we can add prefix matching, which isn't something
- that is supported by default.
+@dataclass
+class Phrase:
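+ """Words from a double-quoted section of the query, to be matched in order."""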
+ phrase: List[str]
+
+
+class SearchToken(enum.Enum):
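+ """Query operators: unary NOT plus the binary OR and (implicit) AND."""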
+ Not = enum.auto()
+ Or = enum.auto()
+ And = enum.auto()
+
+
+Token = Union[str, Phrase, SearchToken]
+TokenList = List[Token]
+
+
+def _is_stop_word(word: str) -> bool:
+ # TODO Pull these out of the dictionary:
+ # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
+ return word in {"the", "a", "you", "me", "and", "but"}
+
+
+def _tokenize_query(query: str) -> TokenList:
"""
+ Convert the user-supplied `query` into a TokenList, which can be translated into
+ some DB-specific syntax.
+
+ The following constructs are supported:
+
+ - phrase queries using "double quotes"
+ - case-insensitive `or` and `and` operators
+ - negation of a keyword via unary `-`, e.g. 'include -exclude'
- # Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ The tokenizer differs from websearch_to_tsquery in the following ways:
- if isinstance(database_engine, PostgresEngine):
- return " & ".join(result + ":*" for result in results)
- elif isinstance(database_engine, Sqlite3Engine):
- return " & ".join(result + "*" for result in results)
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
+ - Stop words are not removed.
+ - Unclosed phrases are treated differently: text after an unmatched double
+ quote is still collected into a phrase.
+
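+ For example (an illustrative sketch, not output captured from a real run):
+
+ _tokenize_query('cat -dog "hot dogs"') ==
+ ["cat", SearchToken.Not, "dog", SearchToken.And, Phrase(["hot", "dogs"])]
+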
+ """
+ tokens: TokenList = []
+
+ # Find phrases.
+ in_phrase = False
+ parts = deque(query.split('"'))
+ for i, part in enumerate(parts):
+ # The contents inside double quotes are treated as a phrase.
+ in_phrase = bool(i % 2)
+
+ # Pull out the individual words, discarding any non-word characters.
+ words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))
+
+ # Phrases have simplified handling of words.
+ if in_phrase:
+ # Skip stop words.
+ phrase = [word for word in words if not _is_stop_word(word)]
+
+ # Consecutive words are implicitly ANDed together.
+ if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+ tokens.append(SearchToken.And)
+
+ # Add the phrase.
+ tokens.append(Phrase(phrase))
+ continue
+
+ # Otherwise, not in a phrase.
+ while words:
+ word = words.popleft()
+
+ if word.startswith("-"):
+ tokens.append(SearchToken.Not)
+
+ # If there's more of the word left, put it back to be processed again.
+ word = word[1:]
+ if word:
+ words.appendleft(word)
+ elif word.lower() == "or":
+ tokens.append(SearchToken.Or)
+ else:
+ # Skip stop words.
+ if _is_stop_word(word):
+ continue
+
+ # Consecutive words are implicitly ANDed together.
+ if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+ tokens.append(SearchToken.And)
+
+ # Add the search term.
+ tokens.append(word)
+
+ return tokens
+
+
+def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
+ """
+ Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
+ Assumes SQLite was compiled with the enhanced query syntax.
+
+ Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
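+
+ For example (illustrative): the tokens ["cat", SearchToken.Not, "dog"]
+ render as 'cat NOT dog'.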
+ """
+ match_query = []
+ for token in tokens:
+ if isinstance(token, str):
+ match_query.append(token)
+ elif isinstance(token, Phrase):
+ match_query.append('"' + " ".join(token.phrase) + '"')
+ elif token == SearchToken.Not:
+ # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
+ # term has already been added before this.
+ match_query.append(" NOT ")
+ elif token == SearchToken.Or:
+ match_query.append(" OR ")
+ elif token == SearchToken.And:
+ match_query.append(" AND ")
+ else:
+ raise ValueError(f"unknown token {token}")
+
+ return "".join(match_query)
+
+
+def _parse_query_for_sqlite(search_term: str) -> str:
+ """Takes a plain unicode string from the user and converts it into a form
+ that can be passed to sqllite's matchinfo().
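+
+ For example (illustrative, assuming the tokenizer behaviour described above):
+ _parse_query_for_sqlite('cat -dog "hot dogs"') == 'cat NOT dog AND "hot dogs"'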
+ """
+ return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
|