diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 29bf47befc..df7f8a43b7 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -236,7 +236,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
- COALESCE(locked, FALSE) AS locked
+ COALESCE(locked, FALSE) AS locked,
+ suspended
FROM users
WHERE name = ?
""",
@@ -261,6 +262,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
shadow_banned,
approved,
locked,
+ suspended,
) = row
return UserInfo(
@@ -277,6 +279,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_type=user_type,
approved=bool(approved),
locked=bool(locked),
+ suspended=bool(suspended),
)
return await self.db_pool.runInteraction(
@@ -1180,6 +1183,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Convert the potential integer into a boolean.
return bool(res)
+ @cached()
+ async def get_user_suspended_status(self, user_id: str) -> bool:
+ """
+ Determine whether the user's account is suspended.
+ Args:
+ user_id: The user ID of the user in question
+ Returns:
+ True if the user's account is suspended, false if it is not suspended or
+ if the user ID cannot be found.
+ """
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="suspended",
+ allow_none=True,
+ desc="get_user_suspended",
+ )
+
+ return bool(res)
+
async def get_threepid_validation_session(
self,
medium: Optional[str],
@@ -2213,6 +2237,35 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None:
+ """
+ Set whether the user's account is suspended in the `users` table.
+
+ Args:
+ user_id: The user ID of the user in question
+ suspended: True if the user is suspended, false if not
+ """
+ await self.db_pool.runInteraction(
+ "set_user_suspended_status",
+ self.set_user_suspended_status_txn,
+ user_id,
+ suspended,
+ )
+
+ def set_user_suspended_status_txn(
+ self, txn: LoggingTransaction, user_id: str, suspended: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"suspended": suspended},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_suspended_status, (user_id,)
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
"""Set the `locked` property for the provided user to the provided value.
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 4a0afb50ac..20fcfd3122 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -470,6 +470,8 @@ class SearchStore(SearchBackgroundUpdateStore):
count_args = args
count_clauses = clauses
+ sqlite_highlights: List[str] = []
+
if isinstance(self.database_engine, PostgresEngine):
search_query = search_term
sql = """
@@ -486,7 +488,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
- search_query = _parse_query_for_sqlite(search_term)
+ search_query, sqlite_highlights = _parse_query_for_sqlite(search_term)
sql = """
SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
@@ -531,9 +533,11 @@ class SearchStore(SearchBackgroundUpdateStore):
event_map = {ev.event_id: ev for ev in events}
- highlights = None
+ highlights: Collection[str] = []
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(search_query, events)
+ else:
+ highlights = sqlite_highlights
count_sql += " GROUP BY room_id"
@@ -597,6 +601,8 @@ class SearchStore(SearchBackgroundUpdateStore):
count_args = list(args)
count_clauses = list(clauses)
+ sqlite_highlights: List[str] = []
+
if pagination_token:
try:
origin_server_ts_str, stream_str = pagination_token.split(",")
@@ -647,7 +653,7 @@ class SearchStore(SearchBackgroundUpdateStore):
CROSS JOIN events USING (event_id)
WHERE
"""
- search_query = _parse_query_for_sqlite(search_term)
+ search_query, sqlite_highlights = _parse_query_for_sqlite(search_term)
args = [search_query] + args
count_sql = """
@@ -694,9 +700,11 @@ class SearchStore(SearchBackgroundUpdateStore):
event_map = {ev.event_id: ev for ev in events}
- highlights = None
+ highlights: Collection[str] = []
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(search_query, events)
+ else:
+ highlights = sqlite_highlights
count_sql += " GROUP BY room_id"
@@ -892,19 +900,25 @@ def _tokenize_query(query: str) -> TokenList:
return tokens
-def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
+def _tokens_to_sqlite_match_query(tokens: TokenList) -> Tuple[str, List[str]]:
"""
Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
Assume sqlite was compiled with enhanced query syntax.
+ Returns the sqlite-formatted query string and the tokenized search terms
+ that can be used as highlights.
+
Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
"""
match_query = []
+ highlights = []
for token in tokens:
if isinstance(token, str):
match_query.append(token)
+ highlights.append(token)
elif isinstance(token, Phrase):
match_query.append('"' + " ".join(token.phrase) + '"')
+ highlights.append(" ".join(token.phrase))
elif token == SearchToken.Not:
# TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
# term has already been added before this.
@@ -916,11 +930,14 @@ def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
else:
raise ValueError(f"unknown token {token}")
- return "".join(match_query)
+ return "".join(match_query), highlights
-def _parse_query_for_sqlite(search_term: str) -> str:
+def _parse_query_for_sqlite(search_term: str) -> Tuple[str, List[str]]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to sqllite's matchinfo().
+
+ Returns the converted query string and the tokenized search terms
+ that can be used as highlights.
"""
return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 08e0241f68..770802483c 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -660,6 +660,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit=limit,
retcols=("room_id", "stream_ordering"),
order_direction=order,
+ keyvalues={"destination": destination},
),
)
return rooms, count
|