diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
def _store_dehydrated_device_txn(
- self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ device_data: str,
+ time: int,
+ keys: Optional[JsonDict] = None,
) -> Optional[str]:
+ # TODO: make keys non-optional once support for msc2697 is dropped
+ if keys:
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ # Type ignore - this function is defined on EndToEndKeyStore which we do
+ # have access to due to hs.get_datastore() "magic"
+ self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
+ txn, user_id, device_id, time, device_keys
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append(
+ (
+ algorithm,
+ key_id,
+ encode_canonical_json(key_obj).decode("ascii"),
+ )
+ )
+ self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+ fallback_keys = keys.get("fallback_keys", None)
+ if fallback_keys:
+ self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)
+
return old_device_id
async def store_dehydrated_device(
- self, user_id: str, device_id: str, device_data: JsonDict
+ self,
+ user_id: str,
+ device_id: str,
+ device_data: JsonDict,
+ time_now: int,
+ keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
+ time_now: current time at the request in milliseconds
+ keys: keys for the dehydrated device
+
Returns:
device id of the user's previous dehydrated device, if any
"""
+
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
+ time_now,
+ keys,
)
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("new_keys", str(new_keys))
- # We are protected from race between lookup and insertion due to
- # a unique constraint. If there is a race of two calls to
- # `add_e2e_one_time_keys` then they'll conflict and we will only
- # insert one set.
- self.db_pool.simple_insert_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- keys=(
- "user_id",
- "device_id",
- "algorithm",
- "key_id",
- "ts_added_ms",
- "key_json",
- ),
- values=[
- (user_id, device_id, algorithm, key_id, time_now, json_bytes)
- for algorithm, key_id, json_bytes in new_keys
- ],
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
await self.db_pool.runInteraction(
- "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+ "add_e2e_one_time_keys_insert",
+ self._add_e2e_one_time_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ new_keys,
+ )
+
+ def _add_e2e_one_time_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
+ """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+ Args:
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+ that the key JSON must be in canonical JSON form
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("new_keys", str(new_keys))
+ # We are protected from race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
+ values=[
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
@cached(max_entries=10000)
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id: str,
fallback_keys: JsonDict,
) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
+
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
"""
- def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("time_now", time_now)
- set_tag("device_keys", str(device_keys))
+ return await self.db_pool.runInteraction(
+ "set_e2e_device_keys",
+ self._set_e2e_device_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ device_keys,
+ )
- old_key_json = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
+ def _set_e2e_device_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ device_keys: JsonDict,
+ ) -> bool:
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", str(device_keys))
+
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
- if old_key_json == new_key_json:
- log_kv({"Message": "Device key already stored."})
- return False
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
- self.db_pool.simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
- log_kv({"message": "Device keys stored."})
- return True
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
- return await self.db_pool.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
)
+ log_kv({"message": "Device keys stored."})
+ return True
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index fff417f9e3..047de6283a 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 1666e3c43b..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
import itertools
import json
import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- self._get_server_keys_json.invalidate((((server_name, key_id),)))
+ await self.invalidate_cache_and_stream(
+ "_get_server_keys_json", ((server_name, key_id),)
+ )
+ await self.invalidate_cache_and_stream(
+ "get_server_key_json_for_remote", (server_name, key_id)
+ )
@cached()
def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
- async def get_server_keys_json_for_remote(
- self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- """Retrieve the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
+ @cached()
+ def get_server_key_json_for_remote(
+ self,
+ server_name: str,
+ key_id: str,
+ ) -> Optional[FetchKeyResultForRemote]:
+ raise NotImplementedError()
- Args:
- server_keys: List of (server_name, key_id, source) triplets.
+ @cachedList(
+ cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+ )
+ async def get_server_keys_json_for_remote(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ """Fetch the cached keys for the given server/key IDs.
- Returns:
- A mapping from (server_name, key_id, source) triplets to a list of dicts
+ If we have multiple entries for a given key ID, returns the most recent.
"""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
+ )
- def _get_server_keys_json_txn(
- txn: LoggingTransaction,
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
+ if not rows:
+ return {}
+
+ # We sort the rows so that the most recently added entry is picked up.
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
- return await self.db_pool.runInteraction(
- "get_server_keys_json", _get_server_keys_json_txn
+ async def get_all_server_keys_json_for_remote(
+ self,
+ server_name: str,
+ ) -> Dict[str, FetchKeyResultForRemote]:
+ """Fetch the cached keys for the given server.
+
+ If we have multiple entries for a given key ID, returns the most recent.
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
)
+
+ if not rows:
+ return {}
+
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 1680bf6168..54d40e7a3a 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -96,6 +95,10 @@ class LockStore(SQLBaseStore):
self._acquiring_locks: Set[Tuple[str, str]] = set()
+ self._clock.looping_call(
+ self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0
+ )
+
@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@@ -216,6 +219,7 @@ class LockStore(SQLBaseStore):
lock_name,
lock_key,
write,
+ db_autocommit=True,
)
except self.database_engine.module.IntegrityError:
return None
@@ -233,61 +237,22 @@ class LockStore(SQLBaseStore):
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
- #
- # Before that though we clear the table of any stale locks.
now = self._clock.time_msec()
token = random_string(6)
- delete_sql = """
- DELETE FROM worker_read_write_locks
- WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
- """
-
- insert_sql = """
- INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # For Postgres we can send these queries at the same time.
- txn.execute(
- delete_sql + ";" + insert_sql,
- (
- # DELETE args
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- # UPSERT args
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
- else:
- # For SQLite these need to be two queries.
- txn.execute(
- delete_sql,
- (
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- ),
- )
- txn.execute(
- insert_sql,
- (
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="worker_read_write_locks",
+ values={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "write_lock": write,
+ "instance_name": self._instance_name,
+ "token": token,
+ "last_renewed_ts": now,
+ },
+ )
lock = Lock(
self._reactor,
@@ -351,6 +316,24 @@ class LockStore(SQLBaseStore):
return locks
+ @wrap_as_background_process("_reap_stale_read_write_locks")
+ async def _reap_stale_read_write_locks(self) -> None:
+ delete_sql = """
+ DELETE FROM worker_read_write_locks
+ WHERE last_renewed_ts < ?
+ """
+
+ def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None:
+ txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,))
+ if txn.rowcount:
+ logger.info("Reaped %d stale locks", txn.rowcount)
+
+ await self.db_pool.runInteraction(
+ "_reap_stale_read_write_locks",
+ reap_stale_read_write_locks_txn,
+ db_autocommit=True,
+ )
+
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c582cf0573..d3a01d526f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
name, password_hash, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
- COALESCE(approved, TRUE) AS approved
+ COALESCE(approved, TRUE) AS approved,
+ COALESCE(locked, FALSE) AS locked
FROM users
WHERE name = ?
""",
@@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
- boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+ boolean_columns = [
+ "admin",
+ "deactivated",
+ "shadow_banned",
+ "approved",
+ "locked",
+ ]
for column in boolean_columns:
- if not isinstance(row[column], bool):
- row[column] = bool(row[column])
+ row[column] = bool(row[column])
return row
@@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Convert the integer into a boolean.
return res == 1
+ @cached()
+ async def get_user_locked_status(self, user_id: str) -> bool:
+ """Retrieve the value for the `locked` property for the provided user.
+
+ Args:
+ user_id: The ID of the user to retrieve the status for.
+
+ Returns:
+ True if the user was locked, false if the user is still active.
+ """
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="locked",
+ desc="get_user_locked_status",
+ )
+
+ # Convert the potential integer into a boolean.
+ return bool(res)
+
async def get_threepid_validation_session(
self,
medium: Optional[str],
@@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
+ """Set the `locked` property for the provided user to the provided value.
+
+ Args:
+ user_id: The ID of the user to set the status for.
+ locked: The value to set for `locked`.
+ """
+
+ await self.db_pool.runInteraction(
+ "set_user_locked_status",
+ self.set_user_locked_status_txn,
+ user_id,
+ locked,
+ )
+
+ def set_user_locked_status_txn(
+ self, txn: LoggingTransaction, user_id: str, locked: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"locked": locked},
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index f34b7ce8f4..6298f0984d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -19,6 +19,7 @@ from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
+ Counter,
Dict,
Iterable,
List,
@@ -28,8 +29,6 @@ from typing import (
cast,
)
-from typing_extensions import Counter
-
from twisted.internet.defer import DeferredLock
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 2a136f2ff6..f0dc31fee6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -995,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
async def search_user_dir(
- self, user_id: str, search_term: str, limit: int
+ self,
+ user_id: str,
+ search_term: str,
+ limit: int,
+ show_locked_users: bool = False,
) -> SearchResult:
"""Searches for users in directory
@@ -1029,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
+ if not show_locked_users:
+ where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)"
+
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
@@ -1060,6 +1067,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM matching_users as t
INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
ORDER BY
@@ -1115,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
AND value MATCH ?
|