diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a62b4abd4e..cfaedf5e0c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -201,7 +201,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
- order_by: str = UserSortOrder.USER_ID.value,
+ order_by: str = UserSortOrder.NAME.value,
direction: str = "f",
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
@@ -261,6 +261,7 @@ class DataStore(
sql_base = f"""
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
+ LEFT JOIN erased_users AS eu ON u.name = eu.user_id
{where_clause}
"""
sql = "SELECT COUNT(*) as total_users " + sql_base
@@ -269,7 +270,8 @@ class DataStore(
sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
- displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
+ displayname, avatar_url, creation_ts * 1000 as creation_ts, approved,
+            eu.user_id IS NOT NULL AS erased
{sql_base}
ORDER BY {order_by_column} {order}, u.name ASC
LIMIT ? OFFSET ?
@@ -277,6 +279,13 @@ class DataStore(
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
+
+    # Some of the boolean values are returned as integers when we're on
+    # SQLite, so coerce them here.
+ columns_to_boolify = ["erased"]
+ for user in users:
+ for column in columns_to_boolify:
+ user[column] = bool(user[column])
+
return users, count
return await self.db_pool.runInteraction(
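For context: the new `erased` column is computed as `eu.user_id IS NOT NULL`, and SQLite reports such expressions as `0`/`1` integers rather than booleans, hence the coercion loop. A minimal standalone sketch of that quirk, using only the stdlib `sqlite3` driver and made-up table contents:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT)")
conn.execute("CREATE TABLE erased_users (user_id TEXT)")
conn.execute("INSERT INTO users VALUES ('@alice:hs')")
conn.execute("INSERT INTO erased_users VALUES ('@alice:hs')")

row = conn.execute(
    "SELECT u.name, eu.user_id IS NOT NULL AS erased"
    " FROM users AS u LEFT JOIN erased_users AS eu ON u.name = eu.user_id"
).fetchone()
assert row == ("@alice:hs", 1)  # SQLite returns 1, not True
erased = bool(row[1])           # hence the boolify loop above
```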
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 64b70a7b28..63046c0527 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
- users_in_room = await self.get_users_in_room(
+ """
+ Get all users in a room that the appservice controls.
+
+ Args:
+ room_id: The room to check in.
+        app_service: The application service to check interest/control against.
+
+ Returns:
+ List of user IDs that the appservice controls.
+ """
+ # We can use `get_local_users_in_room(...)` here because an application service
+    # can only be interested in local users of the server it's on (ignoring any
+    # remote users that might match the user namespace regex).
+ local_users_in_room = await self.get_local_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
- return list(filter(app_service.is_interested_in_user, users_in_room))
+ return list(filter(app_service.is_interested_in_user, local_users_in_room))
class ApplicationServiceStore(ApplicationServiceWorkerStore):
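As a hedged illustration of why the switch is safe: `is_interested_in_user` tests user IDs against the appservice's user namespace regexes, which (for a sanely configured appservice) only ever match IDs on its own homeserver. The regex and user IDs below are invented:

```python
import re

# Hypothetical user namespace of an IRC bridge appservice.
NAMESPACE = re.compile(r"@irc_.*:example\.com")

def is_interested_in_user(user_id: str) -> bool:
    """Stand-in for ApplicationService.is_interested_in_user."""
    return bool(NAMESPACE.match(user_id))

local_users_in_room = ["@irc_alice:example.com", "@bob:example.com"]
print(list(filter(is_interested_in_user, local_users_in_room)))
# -> ['@irc_alice:example.com']
```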
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 830b076a32..979dd4e17e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -274,6 +274,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
destination, int(from_stream_id)
)
if not has_changed:
+ # debugging for https://github.com/matrix-org/synapse/issues/14251
+ issue_8631_logger.debug(
+ "%s: no change between %i and %i",
+ destination,
+ from_stream_id,
+ now_stream_id,
+ )
return now_stream_id, []
updates = await self.db_pool.runInteraction(
@@ -1848,7 +1855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Iterable[str],
+ device_id: str,
hosts: Collection[str],
stream_ids: List[int],
context: Optional[Dict[str, str]],
@@ -1864,6 +1871,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
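+
+    # We only need to send out updates for *our* users; pokes for remote
+    # users' devices are recorded as already sent.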
+ mark_sent = not self.hs.is_mine_id(user_id)
+
+ values = [
+ (
+ destination,
+ next(stream_id_iterator),
+ user_id,
+ device_id,
+ mark_sent,
+ now,
+ encoded_context if whitelisted_homeserver(destination) else "{}",
+ )
+ for destination in hosts
+ ]
+
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1876,23 +1898,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"ts",
"opentracing_context",
),
- values=[
- (
- destination,
- next(stream_id_iterator),
- user_id,
- device_id,
- not self.hs.is_mine_id(
- user_id
- ), # We only need to send out update for *our* users
- now,
- encoded_context if whitelisted_homeserver(destination) else "{}",
- )
- for destination in hosts
- for device_id in device_ids
- ],
+ values=values,
)
+ # debugging for https://github.com/matrix-org/synapse/issues/14251
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ issue_8631_logger.debug(
+ "Recorded outbound pokes for %s:%s with device stream ids %s",
+ user_id,
+ device_id,
+ {
+ stream_id: destination
+ for (destination, stream_id, _, _, _, _, _) in values
+ },
+ )
+
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
@@ -1997,7 +2017,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
- device_ids=[device_id],
+ device_id=device_id,
hosts=hosts,
stream_ids=stream_ids,
context=context,
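The signature change (a single `device_id` instead of `device_ids`) lets the row construction be hoisted out of `simple_insert_many_txn` so the new debug log can reuse it. A sketch of the shape of `values` for one device fanned out to two hosts, with invented data:

```python
hosts = ["one.example.org", "two.example.org"]
stream_id_iterator = iter([101, 102])

# (destination, stream_id, user_id, device_id, sent, ts, opentracing_context)
values = [
    (destination, next(stream_id_iterator), "@alice:hs", "DEVICEID", True, 0, "{}")
    for destination in hosts
]

# The issue-14251 debug log then maps each allocated stream id back to the
# destination it was allocated for:
print({stream_id: destination for (destination, stream_id, *_rest) in values})
# -> {101: 'one.example.org', 102: 'two.example.org'}
```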
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 51416b2236..b6c15f29f8 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -29,6 +29,7 @@ from typing import (
)
from synapse.api.errors import StoreError
+from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -62,7 +63,9 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict], enabled_map: Dict[str, bool]
+ rawrules: List[JsonDict],
+ enabled_map: Dict[str, bool],
+ experimental_config: ExperimentalConfig,
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
@@ -80,7 +83,9 @@ def _load_rules(
push_rules = PushRules(ruleslist)
- filtered_rules = FilteredPushRules(push_rules, enabled_map)
+ filtered_rules = FilteredPushRules(
+ push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled
+ )
return filtered_rules
@@ -160,7 +165,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map)
+ return _load_rules(rows, enabled_map, self.hs.config.experimental)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@@ -219,7 +224,9 @@ class PushRulesWorkerStore(
results: Dict[str, FilteredPushRules] = {}
for user_id, rules in raw_rules.items():
- results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+ results[user_id] = _load_rules(
+ rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
+ )
return results
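Net effect: every loaded rule set now knows whether MSC3664 (related-event push rules) is enabled for this homeserver. Since the real `FilteredPushRules` lives in Synapse's Rust extension, the stub below only illustrates the call shape:

```python
class FilteredPushRulesStub:
    """Illustrative stand-in for the Rust-backed FilteredPushRules."""

    def __init__(self, push_rules, enabled_map, *, msc3664_enabled: bool):
        self.push_rules = push_rules
        self.enabled_map = enabled_map
        # When False, MSC3664 related-event-match rules are filtered out.
        self.msc3664_enabled = msc3664_enabled

rules = FilteredPushRulesStub({}, {}, msc3664_enabled=True)
```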
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2996d6bb4d..0255295317 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ NotFoundError,
+ StoreError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
because this external id is given to an other user."""
+class LoginTokenExpired(Exception):
+ """Exception if the login token sent expired"""
+
+
+class LoginTokenReused(Exception):
+ """Exception if the login token sent was already used"""
+
+
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
If None, the session can be refreshed indefinitely."""
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class LoginTokenLookupResult:
+ """Result of looking up a login token."""
+
+ user_id: str
+ """The user this token belongs to."""
+
+ auth_provider_id: Optional[str]
+ """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+ auth_provider_session_id: Optional[str]
+ """The session ID advertised by the SSO Identity Provider."""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
+ async def add_login_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ expiry_ts: int,
+ auth_provider_id: Optional[str],
+ auth_provider_session_id: Optional[str],
+ ) -> None:
+ """Adds a short-term login token for the given user.
+
+ Args:
+ user_id: The user ID.
+ token: The new login token to add.
+ expiry_ts (milliseconds since the epoch): Time after which the login token
+ cannot be used.
+ auth_provider_id: The SSO Identity Provider that the user authenticated with
+ to get this token, if any
+ auth_provider_session_id: The session ID advertised by the SSO Identity
+ Provider, if any.
+ """
+ await self.db_pool.simple_insert(
+ "login_tokens",
+ {
+ "token": token,
+ "user_id": user_id,
+ "expiry_ts": expiry_ts,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ desc="add_login_token_to_user",
+ )
+
+ def _consume_login_token(
+ self,
+ txn: LoggingTransaction,
+ token: str,
+ ts: int,
+ ) -> LoginTokenLookupResult:
+ values = self.db_pool.simple_select_one_txn(
+ txn,
+ "login_tokens",
+ keyvalues={"token": token},
+ retcols=(
+ "user_id",
+ "expiry_ts",
+ "used_ts",
+ "auth_provider_id",
+ "auth_provider_session_id",
+ ),
+ allow_none=True,
+ )
+
+ if values is None:
+ raise NotFoundError()
+
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "login_tokens",
+ keyvalues={"token": token},
+ updatevalues={"used_ts": ts},
+ )
+ user_id = values["user_id"]
+ expiry_ts = values["expiry_ts"]
+ used_ts = values["used_ts"]
+ auth_provider_id = values["auth_provider_id"]
+ auth_provider_session_id = values["auth_provider_session_id"]
+
+ # Token was already used
+ if used_ts is not None:
+ raise LoginTokenReused()
+
+ # Token expired
+ if ts > int(expiry_ts):
+ raise LoginTokenExpired()
+
+ return LoginTokenLookupResult(
+ user_id=user_id,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
+ """Lookup a login token and consume it.
+
+ Args:
+ token: The login token.
+
+ Returns:
+            The data stored with that token, including the `user_id`.
+
+        Raises:
+            NotFoundError if the login token was not found in the database.
+            LoginTokenExpired if the login token expired.
+            LoginTokenReused if the login token was already used.
+ """
+ return await self.db_pool.runInteraction(
+ "consume_login_token",
+ self._consume_login_token,
+ token,
+ self._clock.time_msec(),
+ )
+
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
@@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
+ # Create a background job for removing expired login tokens
+ if hs.config.worker.run_background_tasks:
+ self._clock.looping_call(
+ self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
+ )
+
async def add_access_token_to_user(
self,
user_id: str,
@@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
approved,
)
+ @wrap_as_background_process("delete_expired_login_tokens")
+ async def _delete_expired_login_tokens(self) -> None:
+ """Remove login tokens with expiry dates that have passed."""
+
+ def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
+ sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
+ txn.execute(sql, (ts,))
+
+ # We keep the expired tokens for an extra 5 minutes so we can measure how many
+ # times a token is being used after its expiry
+ now = self._clock.time_msec()
+ await self.db_pool.runInteraction(
+ "delete_expired_login_tokens",
+ _delete_expired_login_tokens_txn,
+ now - (5 * 60 * 1000),
+ )
+
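Taken together, the methods added above implement single-use, time-limited login tokens. A hedged sketch of how a login flow might consume one (the handler name and error mapping are illustrative, not Synapse's actual handler):

```python
from synapse.api.errors import NotFoundError

async def resolve_login_token(store, token: str) -> str:
    """Illustrative only: map an m.login.token submission to a user ID."""
    try:
        res = await store.consume_login_token(token)
    except LoginTokenReused:
        # Possible token theft: a real handler might also revoke the session
        # that was created from the first use.
        raise
    except (NotFoundError, LoginTokenExpired):
        raise  # surfaces to the client as an invalid/expired token
    return res.user_id
```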
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2ed6ad754f..e56a13f21e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the forward extremities of those rooms will exclude most members. We may also
calculate room state incorrectly for such rooms and believe that a member is or
is not in the room when the opposite is true.
+
+ Note: If you only care about users in the room local to the homeserver, use
+    `get_local_users_in_room(...)` instead, which is more performant.
"""
return await self.db_pool.simple_select_onecol(
table="current_state_events",
@@ -707,8 +710,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# 250 users is pretty arbitrary but the data can be quite large if users
# are in many rooms.
- for user_ids in batch_iter(user_ids, 250):
- all_user_rooms.update(await self._get_rooms_for_users(user_ids))
+ for batch_user_ids in batch_iter(user_ids, 250):
+ all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids))
return all_user_rooms
@@ -742,7 +745,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# user and the set of other users, and then checking if there is any
# overlap.
sql = f"""
- SELECT b.state_key
+ SELECT DISTINCT b.state_key
FROM (
SELECT room_id FROM current_state_events
WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
@@ -751,7 +754,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
SELECT room_id, state_key FROM current_state_events
WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
) AS b using (room_id)
- LIMIT 1
"""
txn.execute(sql, (user_id, *args))
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1b79acf955..594b935614 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -11,10 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import logging
import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
+from collections import deque
+from dataclasses import dataclass
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
import attr
@@ -27,7 +39,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = _parse_query(self.database_engine, search_term)
-
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore):
count_clauses = clauses
if isinstance(self.database_engine, PostgresEngine):
+ search_query = search_term
+ tsquery_func = self.database_engine.tsquery_func
sql = (
- "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
+ f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,"
" room_id, event_id"
" FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?)"
+ f" WHERE vector @@ {tsquery_func}('english', ?)"
)
args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?)"
+ f" WHERE vector @@ {tsquery_func}('english', ?)"
)
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
+ search_query = _parse_query_for_sqlite(search_term)
+
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" FROM event_search"
@@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ?"
)
- count_args = [search_term] + count_args
+ count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = await self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(
+ search_query, events, tsquery_func
+ )
count_sql += " GROUP BY room_id"
@@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore):
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
-
return {
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
@@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore):
Each match as a dictionary.
"""
clauses = []
-
- search_query = _parse_query(self.database_engine, search_term)
-
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore):
args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine):
+ search_query = search_term
+ tsquery_func = self.database_engine.tsquery_func
sql = (
- "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
+ f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?) AND "
+ f" WHERE vector @@ {tsquery_func}('english', ?) AND "
)
args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
- " WHERE vector @@ to_tsquery('english', ?) AND "
+ f" WHERE vector @@ {tsquery_func}('english', ?) AND "
)
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
+
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
@@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore):
" CROSS JOIN events USING (event_id)"
" WHERE "
)
+ search_query = _parse_query_for_sqlite(search_term)
args = [search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ? AND "
)
- count_args = [search_term] + count_args
+ count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = await self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(
+ search_query, events, tsquery_func
+ )
count_sql += " GROUP BY room_id"
@@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore):
}
async def _find_highlights_in_postgres(
- self, search_query: str, events: List[EventBase]
+ self, search_query: str, events: List[EventBase], tsquery_func: str
) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore):
Args:
search_query
events: A list of events
+ tsquery_func: The tsquery_* function to use when making queries
Returns:
A set of strings.
@@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore):
while stop_sel in value:
stop_sel += ">"
- query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
+ query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % (
_to_postgres_options(
{
"StartSel": start_sel,
@@ -760,20 +779,127 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
- """Takes a plain unicode string from the user and converts it into a form
- that can be passed to database.
- We use this so that we can add prefix matching, which isn't something
- that is supported by default.
+@dataclass
+class Phrase:
+ phrase: List[str]
+
+
+class SearchToken(enum.Enum):
+ Not = enum.auto()
+ Or = enum.auto()
+ And = enum.auto()
+
+
+Token = Union[str, Phrase, SearchToken]
+TokenList = List[Token]
+
+
+def _is_stop_word(word: str) -> bool:
+ # TODO Pull these out of the dictionary:
+ # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
+ return word in {"the", "a", "you", "me", "and", "but"}
+
+
+def _tokenize_query(query: str) -> TokenList:
+ """
+ Convert the user-supplied `query` into a TokenList, which can be translated into
+ some DB-specific syntax.
+
+ The following constructs are supported:
+
+ - phrase queries using "double quotes"
+ - case-insensitive `or` and `and` operators
+    - negation of a keyword via unary `-`, e.g. 'include -exclude'
+
+ The following differs from websearch_to_tsquery:
+
+    - Stop words are only removed from a small hard-coded list, rather than
+      using a full dictionary.
+ - Unclosed phrases are treated differently.
+
+ """
+ tokens: TokenList = []
+
+ # Find phrases.
+ in_phrase = False
+ parts = deque(query.split('"'))
+ for i, part in enumerate(parts):
+        # The content inside double quotes is treated as a phrase.
+ in_phrase = bool(i % 2)
+
+ # Pull out the individual words, discarding any non-word characters.
+ words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))
+
+ # Phrases have simplified handling of words.
+ if in_phrase:
+ # Skip stop words.
+ phrase = [word for word in words if not _is_stop_word(word)]
+
+ # Consecutive words are implicitly ANDed together.
+ if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+ tokens.append(SearchToken.And)
+
+ # Add the phrase.
+ tokens.append(Phrase(phrase))
+ continue
+
+ # Otherwise, not in a phrase.
+ while words:
+ word = words.popleft()
+
+ if word.startswith("-"):
+ tokens.append(SearchToken.Not)
+
+            # If there's more of the word left, put it back to be processed again.
+ word = word[1:]
+ if word:
+ words.appendleft(word)
+ elif word.lower() == "or":
+ tokens.append(SearchToken.Or)
+ else:
+ # Skip stop words.
+ if _is_stop_word(word):
+ continue
+
+ # Consecutive words are implicitly ANDed together.
+ if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
+ tokens.append(SearchToken.And)
+
+ # Add the search term.
+ tokens.append(word)
+
+ return tokens
+
+
+def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
+ """
+    Convert the list of tokens to a string suitable for passing to SQLite's MATCH.
+    Assumes SQLite was compiled with the enhanced query syntax.
+
+ Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
"""
+ match_query = []
+ for token in tokens:
+ if isinstance(token, str):
+ match_query.append(token)
+ elif isinstance(token, Phrase):
+ match_query.append('"' + " ".join(token.phrase) + '"')
+ elif token == SearchToken.Not:
+ # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
+ # term has already been added before this.
+ match_query.append(" NOT ")
+ elif token == SearchToken.Or:
+ match_query.append(" OR ")
+ elif token == SearchToken.And:
+ match_query.append(" AND ")
+ else:
+ raise ValueError(f"unknown token {token}")
+
+ return "".join(match_query)
- # Pull out the individual words, discarding any non-word characters.
- results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
- if isinstance(database_engine, PostgresEngine):
- return " & ".join(result + ":*" for result in results)
- elif isinstance(database_engine, Sqlite3Engine):
- return " & ".join(result + "*" for result in results)
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
+def _parse_query_for_sqlite(search_term: str) -> str:
+ """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to SQLite's MATCH operator.
+ """
+ return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
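On PostgreSQL the raw search term is now handed directly to `websearch_to_tsquery`/`plainto_tsquery`, so only the SQLite path needs this translation. A worked trace of the tokenizer, with outputs derived from the code above:

```python
tokens = _tokenize_query('"quick fox" or -jumps lazy')
# -> [Phrase(phrase=['quick', 'fox']), SearchToken.Or, SearchToken.Not,
#     'jumps', SearchToken.And, 'lazy']

print(_tokens_to_sqlite_match_query(tokens))
# -> "quick fox" OR  NOT jumps AND lazy
# (the doubled space around NOT is harmless to the FTS query parser)
```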
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index d8c0f64d9a..9bf74bbf59 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -170,6 +170,22 @@ class PostgresEngine(
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True
+ @property
+ def tsquery_func(self) -> str:
+ """
+        Selects a tsquery_* function to use.
+
+ Ref: https://www.postgresql.org/docs/current/textsearch-controls.html
+
+ Returns:
+ The function name.
+ """
+ # Postgres 11 added support for websearch_to_tsquery.
+ assert self._version is not None
+ if self._version >= 110000:
+ return "websearch_to_tsquery"
+ return "plainto_tsquery"
+
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
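The cutoff matters because `websearch_to_tsquery` first appeared in PostgreSQL 11 (`110000` in the integer server-version encoding). For illustration, the two functions parse the same input quite differently (examples adapted from the PostgreSQL docs):

```python
# websearch_to_tsquery('english', '"sad cat" or "fat rat"')
#   -> 'sad' <-> 'cat' | 'fat' <-> 'rat'    (phrases and OR are understood)
#
# plainto_tsquery('english', '"sad cat" or "fat rat"')
#   -> 'sad' & 'cat' & 'fat' & 'rat'        (quotes/OR ignored, all terms ANDed)
def pick_tsquery_func(num_server_version: int) -> str:
    """Mirror of the property above, for illustration only."""
    return "websearch_to_tsquery" if num_server_version >= 110000 else "plainto_tsquery"
```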
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index faa574dbfd..14260442b6 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -88,6 +88,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
db_conn.create_function("rank", 1, _rank)
db_conn.execute("PRAGMA foreign_keys = ON;")
+
+ # Enable WAL.
+ # see https://www.sqlite.org/wal.html
+ db_conn.execute("PRAGMA journal_mode = WAL;")
db_conn.commit()
def is_deadlock(self, error: Exception) -> bool:
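Unlike `foreign_keys`, which is a per-connection setting, `journal_mode = WAL` persists in the database file, so re-issuing it on every connection is redundant but harmless. A standalone illustration with the stdlib driver:

```python
import sqlite3

conn = sqlite3.connect("homeserver.db")
conn.execute("PRAGMA foreign_keys = ON;")  # must be set on every connection
mode = conn.execute("PRAGMA journal_mode = WAL;").fetchone()[0]
print(mode)  # -> "wal"; stays in effect for future connections to this file
conn.commit()
```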
diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py
new file mode 100644
index 0000000000..3de0a709eb
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py
@@ -0,0 +1,62 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+
+from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None:
+ """
+    Upgrade the event_search table to use the porter tokenizer, if it isn't already.
+
+    Applies only to SQLite.
+ """
+ if not isinstance(database_engine, Sqlite3Engine):
+ return
+
+    # Rebuild the event_search table with tokenize=porter configured.
+ cur.execute("DROP TABLE event_search")
+ cur.execute(
+ """
+ CREATE VIRTUAL TABLE event_search
+ USING fts4 (tokenize=porter, event_id, room_id, sender, key, value )
+ """
+ )
+
+ # Re-run the background job to re-populate the event_search table.
+ cur.execute("SELECT MIN(stream_ordering) FROM events")
+ row = cur.fetchone()
+ min_stream_id = row[0]
+
+    # If there aren't any events, there's nothing to do.
+ if min_stream_id is None:
+ return
+
+ cur.execute("SELECT MAX(stream_ordering) FROM events")
+ row = cur.fetchone()
+ max_stream_id = row[0]
+
+ progress = {
+ "target_min_stream_id_inclusive": min_stream_id,
+ "max_stream_id_exclusive": max_stream_id + 1,
+ }
+ progress_json = json.dumps(progress)
+
+ sql = """
+    INSERT INTO background_updates (ordering, update_name, progress_json)
+ VALUES (?, ?, ?)
+ """
+
+ cur.execute(sql, (7310, "event_search", progress_json))
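For illustration, the porter tokenizer stems terms at both index and query time, so morphological variants match each other. A minimal sketch with made-up data (requires an SQLite build with FTS4 enabled):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE demo USING fts4(tokenize=porter, value)")
conn.execute("INSERT INTO demo (value) VALUES ('the fox jumps over the dog')")

# "jumping" and the indexed "jumps" both stem to "jump", so this matches:
rows = conn.execute("SELECT value FROM demo WHERE value MATCH 'jumping'").fetchall()
print(rows)  # -> [('the fox jumps over the dog',)]
```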
diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql
new file mode 100644
index 0000000000..a39b7bcece
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2022 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Login tokens are short-lived tokens that are used for the m.login.token
+-- login method, mainly during SSO logins.
+CREATE TABLE login_tokens (
+ token TEXT PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ expiry_ts BIGINT NOT NULL,
+ used_ts BIGINT,
+ auth_provider_id TEXT,
+ auth_provider_session_id TEXT
+);
+
+-- We sometimes query them by the session ID we got from the IdP.
+CREATE INDEX login_tokens_auth_provider_idx
+ ON login_tokens (auth_provider_id, auth_provider_session_id);
+
+-- We delete them by their expiration time.
+CREATE INDEX login_tokens_expiry_time_idx
+ ON login_tokens (expiry_ts);
+