diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index d0cf3460da..70ca3e09f7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -324,7 +324,7 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return ({}, {})
+ return {}, {}
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e2d1b758bd..2da2659f41 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -60,7 +60,7 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
- hs.hostname, hs.config.app_service_config_files
+ hs.hostname, hs.config.appservice.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 2712514145..dafba2b03f 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -555,8 +555,11 @@ class ClientIpStore(ClientIpWorkerStore):
return ret
async def get_user_ip_and_agents(
- self, user: UserID
+ self, user: UserID, since_ts: int = 0
) -> List[Dict[str, Union[str, int]]]:
+ """
+ Fetch IP/User Agent connection since a given timestamp.
+ """
user_id = user.to_string()
results = {}
@@ -568,13 +571,23 @@ class ClientIpStore(ClientIpWorkerStore):
) = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
- results[(access_token, ip)] = (user_agent, last_seen)
+ if last_seen >= since_ts:
+ results[(access_token, ip)] = (user_agent, last_seen)
- rows = await self.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "last_seen"],
- desc="get_user_ip_and_agents",
+ def get_recent(txn):
+ txn.execute(
+ """
+ SELECT access_token, ip, user_agent, last_seen FROM user_ips
+ WHERE last_seen >= ? AND user_id = ?
+ ORDER BY last_seen
+ DESC
+ """,
+ (since_ts, user_id),
+ )
+ return txn.fetchall()
+
+ rows = await self.db_pool.runInteraction(
+ desc="get_user_ip_and_agents", func=get_recent
)
results.update(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index c55508867d..3154906d45 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -136,7 +136,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id, last_stream_id
)
if not has_changed:
- return ([], current_stream_id)
+ return [], current_stream_id
def get_new_messages_for_device_txn(txn):
sql = (
@@ -240,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
- return ([], current_stream_id)
+ return [], current_stream_id
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
- return ([], last_stream_id)
+ return [], last_stream_id
@trace
def get_new_messages_for_remote_destination_txn(txn):
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1f0a39eac4..a95ac34f09 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -824,6 +824,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
if otk_row is None:
return None
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d72e716b5c..4a1a2f4a6a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1495,7 +1495,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+ return int(res["topological_ordering"]), int(res["stream_ordering"])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index d213b26703..b76ee51a9b 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -63,7 +63,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
- `config.track_appservice_user_ips` must be set to `true` for this
+ `config.appservice.track_appservice_user_ips` must be set to `true` for this
method to return anything other than native matrix users.
Returns:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bccff5e5b9..3eb30944bf 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -102,15 +102,19 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
(room_id,),
)
rows = txn.fetchall()
- max_depth = max(row[1] for row in rows)
-
- if max_depth < token.topological:
- # We need to ensure we don't delete all the events from the database
- # otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremities)
- raise SynapseError(
- 400, "topological_ordering is greater than forward extremeties"
- )
+ # if we already have no forwards extremities (for example because they were
+ # cleared out by the `delete_old_current_state_events` background database
+ # update), then we may as well carry on.
+ if rows:
+ max_depth = max(row[1] for row in rows)
+
+ if max_depth < token.topological:
+ # We need to ensure we don't delete all the events from the database
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backwards extremities)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremities"
+ )
logger.info("[purge] looking for events to delete")
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fafadb88fc..c83089ee63 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -388,7 +388,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"get_users_expiring_soon",
select_users_txn,
self._clock.time_msec(),
- self.config.account_validity_renew_at,
+ self.config.account_validity.account_validity_renew_at,
)
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@@ -2015,7 +2015,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
(user_id_obj.localpart, create_profile_with_displayname),
)
- if self.hs.config.stats_enabled:
+ if self.hs.config.stats.stats_enabled:
# we create a new completed user statistics row
# we don't strictly need current_token since this user really can't
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a4ec6bc328..ddb162a4fc 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -82,7 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if (
self.hs.config.worker.run_background_tasks
- and self.hs.config.metrics_flags.known_servers
+ and self.hs.config.metrics.metrics_flags.known_servers
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1c642c753b..9eb74a81a0 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,12 +15,12 @@
import logging
import re
from collections import namedtuple
-from typing import Collection, List, Optional, Set
+from typing import Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -32,14 +32,24 @@ SearchEntry = namedtuple(
)
+def _clean_value_for_search(value: str) -> str:
+ """
+ Replaces any null code points in the string with spaces as
+ Postgres and SQLite do not like the insertion of strings with
+ null code points into the full-text search tables.
+ """
+ return value.replace("\u0000", " ")
+
+
class SearchWorkerStore(SQLBaseStore):
- def store_search_entries_txn(self, txn, entries):
+ def store_search_entries_txn(
+ self, txn: LoggingTransaction, entries: Iterable[SearchEntry]
+ ) -> None:
"""Add entries to the search table
Args:
- txn (cursor):
- entries (iterable[SearchEntry]):
- entries to be added to the table
+ txn:
+ entries: entries to be added to the table
"""
if not self.hs.config.enable_search:
return
@@ -55,7 +65,7 @@ class SearchWorkerStore(SQLBaseStore):
entry.event_id,
entry.room_id,
entry.key,
- entry.value,
+ _clean_value_for_search(entry.value),
entry.stream_ordering,
entry.origin_server_ts,
)
@@ -70,11 +80,16 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,?)"
)
args = (
- (entry.event_id, entry.room_id, entry.key, entry.value)
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ _clean_value_for_search(entry.value),
+ )
for entry in entries
)
-
txn.execute_batch(sql, args)
+
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -646,6 +661,7 @@ class SearchStore(SearchBackgroundUpdateStore):
for key in ("body", "name", "topic"):
v = event.content.get(key, None)
if v:
+ v = _clean_value_for_search(v)
values.append(v)
if not values:
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index bff7d0404f..a89747d741 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -58,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return (max_stream_id, [])
+ return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 343d6efc92..e20033bb28 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -98,7 +98,7 @@ class StatsStore(StateDeltasStore):
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
- self.stats_enabled = hs.config.stats_enabled
+ self.stats_enabled = hs.config.stats.stats_enabled
self.stats_delta_processing_lock = DeferredLock()
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 959f13de47..dc7884b1c0 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,6 +39,8 @@ import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from frozendict import frozendict
+
from twisted.internet import defer
from synapse.api.filtering import Filter
@@ -379,7 +381,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, positions)
+ return RoomStreamToken(None, min_pos, frozendict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -622,7 +624,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._set_before_and_after(events, rows)
- return (events, token)
+ return events, token
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: RoomStreamToken
@@ -1240,7 +1242,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
self._set_before_and_after(events, rows)
- return (events, token)
+ return events, token
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 718f3e9976..90d65edc42 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,14 +14,28 @@
import logging
import re
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ cast,
+)
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id, get_localpart_from_id
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
- async def _populate_user_directory_createtables(self, progress, batch_size):
+ async def _populate_user_directory_createtables(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# Get all the rooms that we want to process.
- def _make_staging_area(txn):
+ def _make_staging_area(txn: LoggingTransaction) -> None:
sql = (
"CREATE TABLE IF NOT EXISTS "
+ TEMP_TABLE
@@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
- async def _populate_user_directory_cleanup(self, progress, batch_size):
+ async def _populate_user_directory_cleanup(
+ self,
+ progress: JsonDict,
+ batch_size: int,
+ ) -> int:
"""
Update the user directory stream position, then clean up the old tables.
"""
position = await self.db_pool.simple_select_one_onecol(
- TEMP_TABLE + "_position", None, "position"
+ TEMP_TABLE + "_position", {}, "position"
)
await self.update_user_directory_stream_pos(position)
- def _delete_staging_area(txn):
+ def _delete_staging_area(txn: LoggingTransaction) -> None:
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
return 1
- async def _populate_user_directory_process_rooms(self, progress, batch_size):
+ async def _populate_user_directory_process_rooms(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
+ Rescan the state of all rooms so we can track
+
+ - who's in a public room;
+ - which local users share a private room with other users (local
+ and remote); and
+ - who should be in the user_directory.
+
Args:
progress (dict)
batch_size (int): Maximum number of state events to process
per cycle.
+
+ Returns:
+ number of events processed.
"""
# If we don't have progress filed, delete everything.
if not progress:
await self.delete_all_from_user_dir()
- def _get_next_batch(txn):
+ def _get_next_batch(
+ txn: LoggingTransaction,
+ ) -> Optional[Sequence[Tuple[str, int]]]:
# Only fetch 250 rooms, so we don't fetch too many at once, even
# if those 250 rooms have less than batch_size state events.
sql = """
@@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
TEMP_TABLE + "_rooms",
)
txn.execute(sql)
- rooms_to_work_on = txn.fetchall()
+ rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
if not rooms_to_work_on:
return None
@@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# Get how many are left to process, so we can give status on how
# far we are in processing
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
- progress["remaining"] = txn.fetchone()[0]
+ result = txn.fetchone()
+ assert result is not None
+ progress["remaining"] = result[0]
return rooms_to_work_on
@@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
- async def _populate_user_directory_process_users(self, progress, batch_size):
+ async def _populate_user_directory_process_users(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""
Add all local users to the user directory.
"""
- def _get_next_batch(txn):
+ def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
sql = "SELECT user_id FROM %s LIMIT %s" % (
TEMP_TABLE + "_users",
str(batch_size),
)
txn.execute(sql)
- users_to_work_on = txn.fetchall()
+ user_result = cast(List[Tuple[str]], txn.fetchall())
- if not users_to_work_on:
+ if not user_result:
return None
- users_to_work_on = [x[0] for x in users_to_work_on]
+ users_to_work_on = [x[0] for x in user_result]
# Get how many are left to process, so we can give status on how
# far we are in processing
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
txn.execute(sql)
- progress["remaining"] = txn.fetchone()[0]
+ count_result = txn.fetchone()
+ assert count_result is not None
+ progress["remaining"] = count_result[0]
return users_to_work_on
@@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
- async def is_room_world_readable_or_publicly_joinable(self, room_id):
+ async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event
@@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if not isinstance(avatar_url, str):
avatar_url = None
- def _update_profile_in_user_dir_txn(txn):
+ def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
@@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
for user_id, other_user_id in user_id_tuples
],
value_names=(),
- value_values=None,
+ value_values=(),
desc="add_users_who_share_room",
)
@@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
key_names=["user_id", "room_id"],
key_values=[(user_id, room_id) for user_id in user_ids],
value_names=(),
- value_values=None,
+ value_values=(),
desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory"""
- def _delete_all_from_user_dir_txn(txn):
+ def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_public_rooms")
@@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
@@ -497,16 +542,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ) -> None:
super().__init__(database, db_conn, hs)
self._prefer_local_users_in_search = (
- hs.config.user_directory_search_prefer_local_users
+ hs.config.userdirectory.user_directory_search_prefer_local_users
)
self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None:
- def _remove_from_user_dir_txn(txn):
+ def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
@@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn
)
- async def get_users_in_dir_due_to_room(self, room_id):
+ async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
@@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
room_id
"""
- def _remove_user_who_share_room_txn(txn):
+ def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
@@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- async def get_user_dir_rooms_user_is_in(self, user_id):
+ async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
"""
Returns the rooms that a user is in.
@@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
- def _get_shared_rooms_for_users_txn(txn):
+ def _get_shared_rooms_for_users_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, str]]:
txn.execute(
"""
SELECT p1.room_id
@@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos",
)
- async def search_user_dir(self, user_id, search_term, limit):
+ async def search_user_dir(
+ self, user_id: str, search_term: str, limit: int
+ ) -> JsonDict:
"""Searches for users in directory
Returns:
@@ -687,7 +741,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
}
"""
- if self.hs.config.user_directory_search_all_users:
+ if self.hs.config.userdirectory.user_directory_search_all_users:
join_args = (user_id,)
where_clause = "user_id != ?"
else:
@@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
- ordering_arguments = ()
+ ordering_arguments: Tuple[str, ...] = ()
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return {"limited": limited, "results": results}
-def _parse_query_sqlite(search_term):
+def _parse_query_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
@@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
-def _parse_query_postgres(search_term):
+def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
|