diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 101403578c..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
+
+import attr
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
@@ -28,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+ """This is very similar to UserInfo, but not quite the same."""
+
+ name: str
+ user_type: Optional[str]
+ is_guest: bool
+ admin: bool
+ deactivated: bool
+ shadow_banned: bool
+ displayname: Optional[str]
+ avatar_url: Optional[str]
+ creation_ts: Optional[int]
+ approved: bool
+ erased: bool
+ last_seen_ts: int
+ locked: bool
+
+
class DataStore(
EventsBackgroundUpdatesStore,
ExperimentalFeaturesStore,
@@ -142,26 +163,6 @@ class DataStore(
super().__init__(database, db_conn, hs)
- async def get_users(self) -> List[JsonDict]:
- """Function to retrieve a list of users in users table.
-
- Returns:
- A list of dictionaries representing users.
- """
- return await self.db_pool.simple_select_list(
- table="users",
- keyvalues={},
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin",
- "user_type",
- "deactivated",
- ],
- desc="get_users",
- )
-
async def get_users_paginate(
self,
start: int,
@@ -176,7 +177,7 @@ class DataStore(
approved: bool = True,
not_user_types: Optional[List[str]] = None,
locked: bool = False,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[UserPaginateResponse], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
@@ -202,7 +203,7 @@ class DataStore(
def get_users_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[UserPaginateResponse], int]:
filters = []
args: list = []
@@ -302,13 +303,24 @@ class DataStore(
"""
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
-
- # some of those boolean values are returned as integers when we're on SQLite
- columns_to_boolify = ["erased"]
- for user in users:
- for column in columns_to_boolify:
- user[column] = bool(user[column])
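+ # Build the results directly from the row tuples; the field order below must match
+ # the column order of the SELECT above. The bool() calls normalise the integer
+ # booleans that SQLite returns.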
+ users = [
+ UserPaginateResponse(
+ name=row[0],
+ user_type=row[1],
+ is_guest=bool(row[2]),
+ admin=bool(row[3]),
+ deactivated=bool(row[4]),
+ shadow_banned=bool(row[5]),
+ displayname=row[6],
+ avatar_url=row[7],
+ creation_ts=row[8],
+ approved=bool(row[9]),
+ erased=bool(row[10]),
+ last_seen_ts=row[11],
+ locked=bool(row[12]),
+ )
+ for row in txn
+ ]
return users, count
@@ -316,7 +328,11 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
- async def search_users(self, term: str) -> Optional[List[JsonDict]]:
+ async def search_users(
+ self, term: str
+ ) -> List[
+ Tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]]
+ ]:
"""Function to search users list for one or more users with
the matched term.
@@ -324,15 +340,37 @@ class DataStore(
term: search term
Returns:
- A list of dictionaries or None.
+ A list of tuples of (name, password_hash, is_guest, admin, user_type); password_hash and user_type may be None.
"""
- return await self.db_pool.simple_search_list(
- table="users",
- term=term,
- col="name",
- retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
- desc="search_users",
- )
+
+ def search_users(
+ txn: LoggingTransaction,
+ ) -> List[
+ Tuple[str, Optional[str], Union[int, bool], Union[int, bool], Optional[str]]
+ ]:
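+ # Wrap the term in LIKE wildcards so it matches anywhere within the name column.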
+ search_term = "%%" + term + "%%"
+
+ sql = """
+ SELECT name, password_hash, is_guest, admin, user_type
+ FROM users
+ WHERE name LIKE ?
+ """
+ txn.execute(sql, (search_term,))
+
+ return cast(
+ List[
+ Tuple[
+ str,
+ Optional[str],
+ Union[int, bool],
+ Union[int, bool],
+ Optional[str],
+ ]
+ ],
+ txn.fetchall(),
+ )
+
+ return await self.db_pool.runInteraction("search_users", search_users)
def check_database_before_upgrade(
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 80f146dd53..d7482a1f4e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
+ extra_tables=[
+ ("account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
@@ -103,6 +106,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"AccountDataAndTagsChangeCache", account_max
)
+ self.db_pool.updates.register_background_index_update(
+ update_name="room_account_data_index_room_id",
+ index_name="room_account_data_room_id",
+ table="room_account_data",
+ columns=("room_id",),
+ )
+
self.db_pool.updates.register_background_update_handler(
"delete_account_data_for_deactivated_users",
self._delete_account_data_for_deactivated_users,
@@ -151,10 +161,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in txn
}
return await self.db_pool.runInteraction(
@@ -196,13 +206,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
- rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_data = by_room.setdefault(row["room_id"], {})
+ for room_id, account_data_type, content in txn:
+ room_data = by_room.setdefault(room_id, {})
- room_data[row["account_data_type"]] = db_to_json(row["content"])
+ room_data[account_data_type] = db_to_json(content)
return by_room
@@ -277,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
- ) -> Dict[str, JsonDict]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id, "room_id": room_id},
- ["account_data_type", "content"],
+ ) -> Dict[str, JsonMapping]:
+ rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=["account_data_type", "content"],
+ ),
)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in rows
}
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 0553a0621a..fa7d1c469a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,17 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- List,
- Optional,
- Pattern,
- Sequence,
- Tuple,
- cast,
-)
+from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
from synapse.appservice import (
ApplicationService,
@@ -207,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
- results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state.value}, ["as_id"]
+ results = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="application_services_state",
+ keyvalues={"state": state.value},
+ retcols=("as_id",),
+ ),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
- for res in results:
+ for (as_id,) in results:
for service in as_list:
- if service.id == res["as_id"]:
+ if service.id == as_id:
services.append(service)
return services
@@ -353,21 +348,15 @@ class ApplicationServiceTransactionWorkerStore(
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
- "SELECT * FROM application_services_txns WHERE as_id=?"
+ "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
- return None
-
- entry = rows[0]
-
- return entry
+ return cast(Optional[Tuple[int, str]], txn.fetchone())
entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@@ -376,8 +365,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
- event_ids = db_to_json(entry["event_ids"])
+ txn_id, event_ids_str = entry
+ event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts, device list summaries and unused
@@ -385,7 +375,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
- id=entry["txn_id"],
+ id=txn_id,
events=events,
ephemeral=[],
to_device_messages=[],
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..4d0470ffd9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
+ EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
+ elif row.type == EventsStreamAllStateRow.TypeId:
+ assert isinstance(data, EventsStreamAllStateRow)
+ # Similar to the above, but the entire caches are invalidated. This is
+ # unfortunate for the membership caches, but should recover quickly.
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
+ self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
+ self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..711fdddd4e 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ "_censor_redactions_fetch", sql, before_ts, 100
)
updates = []
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 7da47c3dd7..c006129625 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+import attr
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -42,7 +43,8 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
-class DeviceLastConnectionInfo(TypedDict):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceLastConnectionInfo:
"""Metadata for the last connection seen for a user and device combination"""
# These types must match the columns in the `devices` table
@@ -499,8 +501,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user
Returns:
- A dictionary mapping a tuple of (user_id, device_id) to dicts, with
- keys giving the column names from the devices table.
+ A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
keyvalues = {"user_id": user_id}
@@ -508,7 +509,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
keyvalues["device_id"] = device_id
res = cast(
- List[DeviceLastConnectionInfo],
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
@@ -516,7 +517,16 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
),
)
- return {(d["user_id"], d["device_id"]): d for d in res}
+ return {
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=ip,
+ user_agent=user_agent,
+ last_seen=last_seen,
+ )
+ for user_id, ip, user_agent, device_id, last_seen in res
+ }
async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
@@ -683,8 +693,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user
Returns:
- A dictionary mapping a tuple of (user_id, device_id) to dicts, with
- keys giving the column names from the devices table.
+ A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
@@ -705,13 +714,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
continue
if not device_id or did == device_id:
- ret[(user_id, did)] = {
- "user_id": user_id,
- "ip": ip,
- "user_agent": user_agent,
- "device_id": did,
- "last_seen": last_seen,
- }
+ ret[(user_id, did)] = DeviceLastConnectionInfo(
+ user_id=user_id,
+ ip=ip,
+ user_agent=user_agent,
+ device_id=did,
+ last_seen=last_seen,
+ )
return ret
async def get_user_ip_and_agents(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 744e98c6d0..3e7425d4a6 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -344,18 +344,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
- user_device_dicts = self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- column="user_id",
- iterable=user_ids_to_query,
- keyvalues={"hidden": False},
- retcols=("device_id",),
+ user_device_dicts = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ ),
)
- device_ids_to_query.update(
- {row["device_id"] for row in user_device_dicts}
- )
+ device_ids_to_query.update({row[0] for row in user_device_dicts})
if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -449,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id: str,
device_id: Optional[str],
up_to_stream_id: int,
- limit: int,
+ limit: Optional[int] = None,
) -> int:
"""
Args:
@@ -477,17 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
- ROW_ID_NAME = self.database_engine.row_id_name
-
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
+ limit_statement = "" if limit is None else f"LIMIT {limit}"
sql = f"""
- DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
- SELECT {ROW_ID_NAME} FROM device_inbox
- WHERE user_id = ? AND device_id = ? AND stream_id <= ?
- LIMIT {limit}
+ DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
+ SELECT MAX(stream_id) FROM (
+ SELECT stream_id FROM device_inbox
+ WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+ ORDER BY stream_id
+ {limit_statement}
+ ) AS q1
)
"""
- txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
return txn.rowcount
count = await self.db_pool.runInteraction(
@@ -845,20 +848,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# We exclude hidden devices (such as cross-signing keys) here as they are
# not expected to receive to-device messages.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- column="device_id",
- iterable=devices,
- retcols=("device_id",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ column="device_id",
+ iterable=devices,
+ retcols=("device_id",),
+ ),
)
- for row in rows:
+ for (device_id,) in rows:
# Only insert into the local inbox if the device exists on
# this server
- device_id = row["device_id"]
-
with start_active_span("serialise_to_device_message"):
msg = messages_by_device[device_id]
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index df596f35f9..04d12a876c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
- async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
+ async def get_devices_by_user(
+ self, user_id: str
+ ) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
- and "display_name" for each device.
+ and "display_name" for each device. Display name may be null.
"""
- devices = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user",
+ devices = cast(
+ List[Tuple[str, str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user",
+ ),
)
- return {d["device_id"]: d for d in devices}
+ return {
+ d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
+ for d in devices
+ }
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
- return await self.db_pool.simple_select_list(
- table="device_auth_providers",
- keyvalues={
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- retcols=("user_id", "device_id"),
- desc="get_devices_by_auth_provider_session_id",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ ),
)
@trace
@@ -692,7 +703,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
key_names=("destination", "user_id"),
key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
- value_values=((stream_id,) for _, stream_id in rows),
+ value_values=[(stream_id,) for _, stream_id in rows],
)
# Delete all sent outbound pokes
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
- devices = await self.db_pool.simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={"user_id": user_id},
- retcols=("device_id", "content"),
- desc="get_cached_devices_for_user",
+ devices = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("device_id", "content"),
+ desc="get_cached_devices_for_user",
+ ),
)
- return {
- device["device_id"]: db_to_json(device["content"]) for device in devices
- }
+ return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@@ -882,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
- None,
sql,
from_key,
to_key,
@@ -966,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
- "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ "get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
@@ -1052,16 +1063,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id"),
- desc="get_device_list_last_stream_id_for_remotes",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id"),
+ desc="get_device_list_last_stream_id_for_remotes",
+ ),
)
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
- results.update({row["user_id"]: row["stream_id"] for row in rows})
+ results.update(rows)
return results
@@ -1077,22 +1091,28 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_resync",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ ),
)
else:
- rows = await self.db_pool.simple_select_list(
- table="device_lists_remote_resync",
- keyvalues=None,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ ),
)
- return {row["user_id"] for row in rows}
+ return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
@@ -1413,13 +1433,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_devices_not_accessed_since_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str]]:
sql = """
SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE
"""
txn.execute(sql, (since_ms,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since",
@@ -1427,11 +1447,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
devices: Dict[str, List[str]] = {}
- for row in rows:
+ for user_id, device_id in rows:
# Remote devices are never stale from our point of view.
- if self.hs.is_mine_id(row["user_id"]):
- user_devices = devices.setdefault(row["user_id"], [])
- user_devices.append(row["device_id"])
+ if self.hs.is_mine_id(user_id):
+ user_devices = devices.setdefault(user_id, [])
+ user_devices.append(device_id)
return devices
@@ -1600,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
#
# For each duplicate, we delete all the existing rows and put one back.
- KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
last_row = progress.get(
"last_row",
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1608,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause(
- [(x, last_row[x]) for x in KEY_COLS]
+ [
+ ("stream_id", last_row["stream_id"]),
+ ("destination", last_row["destination"]),
+ ("user_id", last_row["user_id"]),
+ ("device_id", last_row["device_id"]),
+ ]
)
- sql = """
+ sql = f"""
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
FROM device_lists_outbound_pokes
- WHERE %s
- GROUP BY %s
+ WHERE {clause}
+ GROUP BY stream_id, destination, user_id, device_id
HAVING count(*) > 1
- ORDER BY %s
+ ORDER BY stream_id, destination, user_id, device_id
LIMIT ?
- """ % (
- clause, # WHERE
- ",".join(KEY_COLS), # GROUP BY
- ",".join(KEY_COLS), # ORDER BY
- )
+ """
txn.execute(sql, args + [batch_size])
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
- row = None
- for row in rows:
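+ # Track the most recently processed duplicate; the names are initialised so they
+ # are defined even when there are no rows, and the final values are recorded as
+ # the background-update progress marker below.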
+ stream_id, destination, user_id, device_id = None, None, None, None
+ for stream_id, destination, user_id, device_id, _ in rows:
self.db_pool.simple_delete_txn(
txn,
"device_lists_outbound_pokes",
- {x: row[x] for x in KEY_COLS},
+ {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ },
)
- row["sent"] = False
self.db_pool.simple_insert_txn(
txn,
"device_lists_outbound_pokes",
- row,
+ {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ "sent": False,
+ },
)
- if row:
+ if rows:
self.db_pool.updates._background_update_progress_txn(
txn,
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
- {"last_row": row},
+ {
+ "last_row": {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ },
)
return len(rows)
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index d01f28cc80..ad904a26a6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@@ -53,6 +53,13 @@ class EndToEndRoomKeyBackgroundStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ self.db_pool.updates.register_background_index_update(
+ update_name="e2e_room_keys_index_room_id",
+ index_name="e2e_room_keys_room_id",
+ table="e2e_room_keys",
+ columns=("room_id",),
+ )
+
self.db_pool.updates.register_background_update_handler(
"delete_e2e_backup_keys_for_deactivated_users",
self._delete_e2e_backup_keys_for_deactivated_users,
@@ -208,7 +215,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
- StreamKeyType.ROOM: room_key,
+ StreamKeyType.ROOM.value: room_key,
}
)
@@ -267,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = await self.db_pool.simple_select_list(
- table="e2e_room_keys",
- keyvalues=keyvalues,
- retcols=(
- "user_id",
- "room_id",
- "session_id",
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
+ rows = cast(
+ List[Tuple[str, str, int, int, int, str]],
+ await self.db_pool.simple_select_list(
+ table="e2e_room_keys",
+ keyvalues=keyvalues,
+ retcols=(
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ desc="get_e2e_room_keys",
),
- desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
- for row in rows:
- room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
- room_entry["sessions"][row["session_id"]] = {
- "first_message_index": row["first_message_index"],
- "forwarded_count": row["forwarded_count"],
+ for (
+ room_id,
+ session_id,
+ first_message_index,
+ forwarded_count,
+ is_verified,
+ session_data,
+ ) in rows:
+ room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
+ room_entry["sessions"][session_id] = {
+ "first_message_index": first_message_index,
+ "forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
- "is_verified": bool(row["is_verified"]),
- "session_data": db_to_json(row["session_data"]),
+ "is_verified": bool(is_verified),
+ "session_data": db_to_json(session_data),
}
return sessions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 89fac23f93..4f96ac25c7 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,6 +24,7 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -155,7 +156,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
- None,
sql,
now_stream_id,
user_id,
@@ -493,15 +493,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A map from (algorithm, key_id) to json string for key
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="e2e_one_time_keys_json",
- column="key_id",
- iterable=key_ids,
- retcols=("algorithm", "key_id", "key_json"),
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="add_e2e_one_time_keys_check",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ retcols=("algorithm", "key_id", "key_json"),
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="add_e2e_one_time_keys_check",
+ ),
)
- result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+ result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
@@ -921,14 +924,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
}
txn.execute(sql, params)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- user_id = row["user_id"]
- key_type = row["keytype"]
- key = db_to_json(row["keydata"])
+ for user_id, key_type, key_data, _ in txn:
user_keys = result.setdefault(user_id, {})
- user_keys[key_type] = key
+ user_keys[key_type] = db_to_json(key_data)
return result
@@ -988,13 +987,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
- for row in rows:
- key_id: str = row["key_id"]
- target_user_id: str = row["target_user_id"]
- target_device_id: str = row["target_device_id"]
+ for target_user_id, target_device_id, key_id, signature in txn:
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
@@ -1012,13 +1007,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
- user_sigs[key_id] = row["signature"]
+ user_sigs[key_id] = signature
else:
- signatures[from_user_id] = {key_id: row["signature"]}
+ signatures[from_user_id] = {key_id: signature}
else:
- target_user_key["signatures"] = {
- from_user_id: {key_id: row["signature"]}
- }
+ target_user_key["signatures"] = {from_user_id: {key_id: signature}}
return keys
@@ -1118,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
...
async def claim_e2e_one_time_keys(
- self, query_list: Iterable[Tuple[str, str, str, int]]
+ self, query_list: Collection[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
@@ -1128,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
- A tuple pf:
+ A tuple (results, missing) of:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
- A copy of the input which has not been fulfilled.
+ The claims from the input which could not be fulfilled. The returned
+ counts may be less than the input counts; in that case they give the
+ number of claims that were not fulfilled.
"""
-
- @trace
- def _claim_e2e_one_time_key_simple(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that don't support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- sql = """
- SELECT key_id, key_json FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- """
-
- txn.execute(sql, (user_id, device_id, algorithm, count))
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self.db_pool.simple_delete_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- column="key_id",
- values=[otk_row[0] for otk_row in otk_rows],
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- },
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
- @trace
- def _claim_e2e_one_time_key_returning(
- txn: LoggingTransaction,
- user_id: str,
- device_id: str,
- algorithm: str,
- count: int,
- ) -> List[Tuple[str, str]]:
- """Claim OTK for device for DBs that support RETURNING.
-
- Returns:
- A tuple of key name (algorithm + key ID) and key JSON, if an
- OTK was found.
- """
-
- # We can use RETURNING to do the fetch and DELETE in once step.
- sql = """
- DELETE FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- AND key_id IN (
- SELECT key_id FROM e2e_one_time_keys_json
- WHERE user_id = ? AND device_id = ? AND algorithm = ?
- LIMIT ?
- )
- RETURNING key_id, key_json
- """
-
- txn.execute(
- sql,
- (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
- )
- otk_rows = list(txn)
- if not otk_rows:
- return []
-
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return [
- (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
- ]
-
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
missing: List[Tuple[str, str, str, int]] = []
- for user_id, device_id, algorithm, count in query_list:
- if self.database_engine.supports_returning:
- # If we support RETURNING clause we can use a single query that
- # allows us to use autocommit mode.
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
- db_autocommit = True
- else:
- _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
- db_autocommit = False
-
- claim_rows = await self.db_pool.runInteraction(
+ if isinstance(self.database_engine, PostgresEngine):
+ # If we can use execute_values we can use a single batch query
+ # in autocommit mode.
+ unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
+ for user_id, device_id, algorithm, count in query_list:
+ unfulfilled_claim_counts[user_id, device_id, algorithm] = count
+
+ bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
- _claim_e2e_one_time_key,
- user_id,
- device_id,
- algorithm,
- count,
- db_autocommit=db_autocommit,
+ self._claim_e2e_one_time_keys_bulk,
+ query_list,
+ db_autocommit=True,
)
- if claim_rows:
+
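+ # Decrement the per-(user, device, algorithm) count for each key we claimed; any
+ # counts still positive afterwards are reported back as missing.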
+ for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- for claim_row in claim_rows:
- device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+ unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
# Did we get enough OTKs?
- count -= len(claim_rows)
- if count:
- missing.append((user_id, device_id, algorithm, count))
+ missing = [
+ (user, device, alg, count)
+ for (user, device, alg), count in unfulfilled_claim_counts.items()
+ if count > 0
+ ]
+ else:
+ for user_id, device_id, algorithm, count in query_list:
+ claim_rows = await self.db_pool.runInteraction(
+ "claim_e2e_one_time_keys",
+ self._claim_e2e_one_time_key_simple,
+ user_id,
+ device_id,
+ algorithm,
+ count,
+ db_autocommit=False,
+ )
+ if claim_rows:
+ device_results = results.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ )
+ for claim_row in claim_rows:
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ # Did we get enough OTKs?
+ count -= len(claim_rows)
+ if count:
+ missing.append((user_id, device_id, algorithm, count))
return results, missing
@@ -1268,6 +1193,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
+ if isinstance(self.database_engine, PostgresEngine):
+ return await self.db_pool.runInteraction(
+ "_claim_e2e_fallback_keys_bulk",
+ self._claim_e2e_fallback_keys_bulk_txn,
+ query_list,
+ db_autocommit=True,
+ )
+ # Use an UPDATE FROM... RETURNING combined with a VALUES block to do
+ # everything in one query. Note: this is also supported in SQLite 3.33.0
+ # (see https://www.sqlite.org/lang_update.html#update_from), but we do not
+ # have an equivalent of psycopg2's execute_values to do this in one query.
+ else:
+ return await self._claim_e2e_fallback_keys_simple(query_list)
+
+ def _claim_e2e_fallback_keys_bulk_txn(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Efficient implementation of claim_e2e_fallback_keys for Postgres.
+
+ Safe to autocommit: this is a single query.
+ """
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+
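+ # execute_values expands the single placeholder in the VALUES clause into one row
+ # per (user_id, device_id, algorithm, mark_as_used) tuple from query_list.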
+ sql = """
+ WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
+ VALUES ?
+ )
+ UPDATE e2e_fallback_keys_json k
+ SET used = used OR mark_as_used
+ FROM claims
+ WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
+ RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
+ """
+ claimed_keys = cast(
+ List[Tuple[str, str, str, str, str]],
+ txn.execute_values(sql, query_list),
+ )
+
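+ # Only invalidate the unused-fallback-key cache once per (user, device) pair,
+ # however many of its rows were updated.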
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
+ device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
+ return results
+
+ async def _claim_e2e_fallback_keys_simple(
+ self,
+ query_list: Iterable[Tuple[str, str, str, bool]],
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
@@ -1310,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return results
+ @trace
+ def _claim_e2e_one_time_key_simple(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ algorithm: str,
+ count: int,
+ ) -> List[Tuple[str, str]]:
+ """Claim OTK for device for DBs that don't support RETURNING.
+
+ Returns:
+ A tuple of key name (algorithm + key ID) and key JSON, if an
+ OTK was found.
+ """
+
+ sql = """
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (user_id, device_id, algorithm, count))
+ otk_rows = list(txn)
+ if not otk_rows:
+ return []
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ values=[otk_row[0] for otk_row in otk_rows],
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
+
+ @trace
+ def _claim_e2e_one_time_keys_bulk(
+ self,
+ txn: LoggingTransaction,
+ query_list: Iterable[Tuple[str, str, str, int]],
+ ) -> List[Tuple[str, str, str, str, str]]:
+ """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
+
+ Args:
+ query_list: Collection of tuples (user_id, device_id, algorithm, count)
+ as passed to claim_e2e_one_time_keys.
+
+ Returns:
+ A list of tuples (user_id, device_id, algorithm, key_id, key_json)
+ for each OTK claimed.
+ """
+ sql = """
+ WITH claims(user_id, device_id, algorithm, claim_count) AS (
+ VALUES ?
+ ), ranked_keys AS (
+ SELECT
+ user_id, device_id, algorithm, key_id, claim_count,
+ ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
+ FROM e2e_one_time_keys_json
+ JOIN claims USING (user_id, device_id, algorithm)
+ )
+ DELETE FROM e2e_one_time_keys_json k
+ WHERE (user_id, device_id, algorithm, key_id) IN (
+ SELECT user_id, device_id, algorithm, key_id
+ FROM ranked_keys
+ WHERE r <= claim_count
+ )
+ RETURNING user_id, device_id, algorithm, key_id, key_json;
+ """
+ otk_rows = cast(
+ List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
+ )
+
+ seen_user_device: Set[Tuple[str, str]] = set()
+ for user_id, device_id, _, _, _ in otk_rows:
+ if (user_id, device_id) in seen_user_device:
+ continue
+ seen_user_device.add((user_id, device_id))
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return otk_rows
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def __init__(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index d4251be7e7..b8bbd1eccd 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1048,15 +1048,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- column="event_id",
- iterable=event_ids,
- retcols=(
- "event_id",
- "depth",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_max_depth_of",
),
- desc="get_max_depth_of",
)
if not rows:
@@ -1064,10 +1067,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
max_depth_event_id = ""
current_max_depth = 0
- for row in rows:
- if row["depth"] > current_max_depth:
- max_depth_event_id = row["event_id"]
- current_max_depth = row["depth"]
+ for event_id, depth in rows:
+ if depth > current_max_depth:
+ max_depth_event_id = event_id
+ current_max_depth = depth
return max_depth_event_id, current_max_depth
@@ -1077,15 +1080,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
Args:
event_ids: The event IDs to calculate the max depth of.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- column="event_id",
- iterable=event_ids,
- retcols=(
- "event_id",
- "depth",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=(
+ "event_id",
+ "depth",
+ ),
+ desc="get_min_depth_of",
),
- desc="get_min_depth_of",
)
if not rows:
@@ -1093,10 +1099,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
else:
min_depth_event_id = ""
current_min_depth = MAX_DEPTH
- for row in rows:
- if row["depth"] < current_min_depth:
- min_depth_event_id = row["event_id"]
- current_min_depth = row["depth"]
+ for event_id, depth in rows:
+ if depth < current_min_depth:
+ min_depth_event_id = event_id
+ current_min_depth = depth
return min_depth_event_id, current_min_depth
@@ -1552,19 +1558,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A filtered down list of `event_ids` that have previous failed pull attempts.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_failed_pull_attempts",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id",),
- desc="get_event_ids_with_failed_pull_attempts",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id",),
+ desc="get_event_ids_with_failed_pull_attempts",
+ ),
)
- event_ids_with_failed_pull_attempts: Set[str] = {
- row["event_id"] for row in rows
- }
-
- return event_ids_with_failed_pull_attempts
+ return {row[0] for row in rows}
@trace
async def get_event_ids_to_not_pull_from_backoff(
@@ -1584,32 +1589,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again.
"""
- event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
- table="event_failed_pull_attempts",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=(
- "event_id",
- "last_attempt_ts",
- "num_attempts",
+ event_failed_pull_attempts = cast(
+ List[Tuple[str, int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=(
+ "event_id",
+ "last_attempt_ts",
+ "num_attempts",
+ ),
+ desc="get_event_ids_to_not_pull_from_backoff",
),
- desc="get_event_ids_to_not_pull_from_backoff",
)
current_time = self._clock.time_msec()
event_ids_with_backoff = {}
- for event_failed_pull_attempt in event_failed_pull_attempts:
- event_id = event_failed_pull_attempt["event_id"]
+ for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
backoff_end_time = (
- event_failed_pull_attempt["last_attempt_ts"]
+ last_attempt_ts
+ (
2
** min(
- event_failed_pull_attempt["num_attempts"],
+ num_attempts,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
)
)
@@ -1890,21 +1897,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
- rows = await self.db_pool.simple_select_list(
- table="federation_inbound_events_staging",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "event_json"),
- desc="prune_staged_events_in_room_fetch",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ ),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
- for row in rows:
- event_id = row["event_id"]
+ for event_id, event_json in rows:
seen_events.add(event_id)
- event_d = db_to_json(row["event_json"])
+ event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index ed29d1fa5d..e4dc68c0d8 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -182,6 +182,7 @@ class UserPushAction(EmailPushAction):
profile_tag: str
+# TODO This is used as a cached value and is mutable.
@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
@@ -193,7 +194,7 @@ class NotifCounts:
highlight_count: int = 0
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomNotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
@@ -201,7 +202,7 @@ class RoomNotifCounts:
main_timeline: NotifCounts
# Map of thread ID to the notification counts.
- threads: Dict[str, NotifCounts]
+ threads: Mapping[str, NotifCounts]
@staticmethod
def empty() -> "RoomNotifCounts":
@@ -483,7 +484,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return room_to_count
- @cached(tree=True, max_entries=5000, iterable=True)
+ @cached(tree=True, max_entries=5000, iterable=True) # type: ignore[synapse-@cached-mutable]
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 790d058c43..7c34bde3e5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -27,6 +27,7 @@ from typing import (
Optional,
Set,
Tuple,
+ Union,
cast,
)
@@ -78,7 +79,7 @@ class DeltaState:
Attributes:
to_delete: List of type/state_keys to delete from current state
to_insert: Map of state to upsert into current state
- no_longer_in_room: The server is not longer in the room, so the room
+ no_longer_in_room: The server is no longer in the room, so the room
should e.g. be removed from `current_state_events` table.
"""
@@ -130,22 +131,25 @@ class PersistEventsStore:
@trace
async def _persist_events_and_state_updates(
self,
+ room_id: str,
events_and_contexts: List[Tuple[EventBase, EventContext]],
*,
- state_delta_for_room: Dict[str, DeltaState],
- new_forward_extremities: Dict[str, Set[str]],
+ state_delta_for_room: Optional[DeltaState],
+ new_forward_extremities: Optional[Set[str]],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
- forward extremities tables.
+ forward extremities tables.
+
+ Assumes that we are only persisting events for one room at a time.
Args:
+ room_id:
events_and_contexts:
- state_delta_for_room: Map from room_id to the delta to apply to
- room state
- new_forward_extremities: Map from room_id to set of event IDs
- that are the new forward extremities of the room.
+ state_delta_for_room: The delta to apply to the room state
+ new_forward_extremities: A set of event IDs that are the new forward
+ extremities of the room.
use_negative_stream_ordering: Whether to start stream_ordering on
the negative side and decrement. This should be set as True
for backfilled events because backfilled events get a negative
@@ -195,6 +199,7 @@ class PersistEventsStore:
await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
+ room_id=room_id,
events_and_contexts=events_and_contexts,
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
@@ -220,9 +225,9 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, latest_event_ids in new_forward_extremities.items():
+ if new_forward_extremities:
self.store.get_latest_event_ids_in_room.prefill(
- (room_id,), frozenset(latest_event_ids)
+ (room_id,), frozenset(new_forward_extremities)
)
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -335,10 +340,11 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
*,
+ room_id: str,
events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool,
- state_delta_for_room: Dict[str, DeltaState],
- new_forward_extremities: Dict[str, Set[str]],
+ state_delta_for_room: Optional[DeltaState],
+ new_forward_extremities: Optional[Set[str]],
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -346,8 +352,11 @@ class PersistEventsStore:
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
+ Assumes that we are only persisting events for one room at a time.
+
Args:
txn
+ room_id: The room the events are from
events_and_contexts: events to persist
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
@@ -356,10 +365,9 @@ class PersistEventsStore:
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
- state_delta_for_room: The current-state delta for each room.
- new_forward_extremities: The new forward extremities for each room.
- For each room, a list of the event ids which are the forward
- extremities.
+ state_delta_for_room: The current-state delta for the room.
+ new_forward_extremities: The new forward extremities for the room:
+ a set of the event ids which are the forward extremities.
Raises:
PartialStateConflictError: if attempting to persist a partial state event in
@@ -375,14 +383,13 @@ class PersistEventsStore:
#
# Annoyingly SQLite doesn't support row level locking.
if isinstance(self.database_engine, PostgresEngine):
- for room_id in {e.room_id for e, _ in events_and_contexts}:
- txn.execute(
- "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
- (room_id,),
- )
- row = txn.fetchone()
- if row is None:
- raise Exception(f"Room does not exist {room_id}")
+ txn.execute(
+ "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+ (room_id,),
+ )
+ row = txn.fetchone()
+ if row is None:
+ raise Exception(f"Room does not exist {room_id}")
# stream orderings should have been assigned by now
assert min_stream_order
@@ -418,7 +425,9 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+ self._update_room_depths_txn(
+ txn, room_id, events_and_contexts=events_and_contexts
+ )
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -431,11 +440,13 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
- self._update_forward_extremities_txn(
- txn,
- new_forward_extremities=new_forward_extremities,
- max_stream_order=max_stream_order,
- )
+ if new_forward_extremities:
+ self._update_forward_extremities_txn(
+ txn,
+ room_id,
+ new_forward_extremities=new_forward_extremities,
+ max_stream_order=max_stream_order,
+ )
self._persist_transaction_ids_txn(txn, events_and_contexts)
@@ -463,7 +474,10 @@ class PersistEventsStore:
# We call this last as it assumes we've inserted the events into
# room_memberships, where applicable.
# NB: This function invalidates all state related caches
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ if state_delta_for_room:
+ self._update_current_state_txn(
+ txn, room_id, state_delta_for_room, min_stream_order
+ )
def _persist_event_auth_chain_txn(
self,
@@ -501,16 +515,19 @@ class PersistEventsStore:
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="rooms",
- column="room_id",
- iterable={event.room_id for event in events if event.is_state()},
- keyvalues={},
- retcols=("room_id", "has_auth_chain_index"),
+ rows = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
+ ),
)
rooms_using_chain_index = {
- row["room_id"] for row in rows if row["has_auth_chain_index"]
+ room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
@@ -571,19 +588,18 @@ class PersistEventsStore:
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
- rows = db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_to_calculate",
- keyvalues={},
- column="room_id",
- iterable=set(event_to_room_id.values()),
- retcols=("event_id", "type", "state_key"),
+ auth_chain_to_calc_rows = cast(
+ List[Tuple[str, str, str]],
+ db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="room_id",
+ iterable=set(event_to_room_id.values()),
+ retcols=("event_id", "type", "state_key"),
+ ),
)
- for row in rows:
- event_id = row["event_id"]
- event_type = row["type"]
- state_key = row["state_key"]
-
+ for event_id, event_type, state_key in auth_chain_to_calc_rows:
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
@@ -753,23 +769,31 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
- rows = db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_links",
- column="origin_chain_id",
- iterable={chain_id for chain_id, _ in chain_map.values()},
- keyvalues={},
- retcols=(
- "origin_chain_id",
- "origin_sequence_number",
- "target_chain_id",
- "target_sequence_number",
+ auth_chain_rows = cast(
+ List[Tuple[int, int, int, int]],
+ db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
),
)
- for row in rows:
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in auth_chain_rows:
chain_links.add_link(
- (row["origin_chain_id"], row["origin_sequence_number"]),
- (row["target_chain_id"], row["target_sequence_number"]),
+ (origin_chain_id, origin_sequence_number),
+ (target_chain_id, target_sequence_number),
new=False,
)
@@ -1015,74 +1039,75 @@ class PersistEventsStore:
await self.db_pool.runInteraction(
"update_current_state",
self._update_current_state_txn,
- state_delta_by_room={room_id: state_delta},
+ room_id,
+ delta_state=state_delta,
stream_id=stream_ordering,
)
def _update_current_state_txn(
self,
txn: LoggingTransaction,
- state_delta_by_room: Dict[str, DeltaState],
+ room_id: str,
+ delta_state: DeltaState,
stream_id: int,
) -> None:
- for room_id, delta_state in state_delta_by_room.items():
- to_delete = delta_state.to_delete
- to_insert = delta_state.to_insert
-
- # Figure out the changes of membership to invalidate the
- # `get_rooms_for_user` cache.
- # We find out which membership events we may have deleted
- # and which we have added, then we invalidate the caches for all
- # those users.
- members_changed = {
- state_key
- for ev_type, state_key in itertools.chain(to_delete, to_insert)
- if ev_type == EventTypes.Member
- }
+ to_delete = delta_state.to_delete
+ to_insert = delta_state.to_insert
+
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+ # and which we have added, then we invalidate the caches for all
+ # those users.
+ members_changed = {
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
+ if ev_type == EventTypes.Member
+ }
- if delta_state.no_longer_in_room:
- # Server is no longer in the room so we delete the room from
- # current_state_events, being careful we've already updated the
- # rooms.room_version column (which gets populated in a
- # background task).
- self._upsert_room_version_txn(txn, room_id)
+ if delta_state.no_longer_in_room:
+ # Server is no longer in the room so we delete the room from
+ # current_state_events, being careful we've already updated the
+ # rooms.room_version column (which gets populated in a
+ # background task).
+ self._upsert_room_version_txn(txn, room_id)
- # Before deleting we populate the current_state_delta_stream
- # so that async background tasks get told what happened.
- sql = """
+ # Before deleting we populate the current_state_delta_stream
+ # so that async background tasks get told what happened.
+ sql = """
INSERT INTO current_state_delta_stream
(stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
SELECT ?, ?, room_id, type, state_key, null, event_id
FROM current_state_events
WHERE room_id = ?
"""
- txn.execute(sql, (stream_id, self._instance_name, room_id))
+ txn.execute(sql, (stream_id, self._instance_name, room_id))
- # We also want to invalidate the membership caches for users
- # that were in the room.
- users_in_room = self.store.get_users_in_room_txn(txn, room_id)
- members_changed.update(users_in_room)
+ # We also want to invalidate the membership caches for users
+ # that were in the room.
+ users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+ members_changed.update(users_in_room)
- self.db_pool.simple_delete_txn(
- txn,
- table="current_state_events",
- keyvalues={"room_id": room_id},
- )
- else:
- # We're still in the room, so we update the current state as normal.
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="current_state_events",
+ keyvalues={"room_id": room_id},
+ )
+ else:
+ # We're still in the room, so we update the current state as normal.
- # First we add entries to the current_state_delta_stream. We
- # do this before updating the current_state_events table so
- # that we can use it to calculate the `prev_event_id`. (This
- # allows us to not have to pull out the existing state
- # unnecessarily).
- #
- # The stream_id for the update is chosen to be the minimum of the stream_ids
- # for the batch of the events that we are persisting; that means we do not
- # end up in a situation where workers see events before the
- # current_state_delta updates.
- #
- sql = """
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ #
+ # The stream_id for the update is chosen to be the minimum of the stream_ids
+ # for the batch of the events that we are persisting; that means we do not
+ # end up in a situation where workers see events before the
+ # current_state_delta updates.
+ #
+ sql = """
INSERT INTO current_state_delta_stream
(stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
SELECT ?, ?, ?, ?, ?, ?, (
@@ -1090,39 +1115,39 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.execute_batch(
- sql,
+ txn.execute_batch(
+ sql,
+ (
(
- (
- stream_id,
- self._instance_name,
- room_id,
- etype,
- state_key,
- to_insert.get((etype, state_key)),
- room_id,
- etype,
- state_key,
- )
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
- )
- # Now we actually update the current_state_events table
+ stream_id,
+ self._instance_name,
+ room_id,
+ etype,
+ state_key,
+ to_insert.get((etype, state_key)),
+ room_id,
+ etype,
+ state_key,
+ )
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
+ # Now we actually update the current_state_events table
- txn.execute_batch(
- "DELETE FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?",
- (
- (room_id, etype, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
- )
+ txn.execute_batch(
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
- # We include the membership in the current state table, hence we do
- # a lookup when we insert. This assumes that all events have already
- # been inserted into room_memberships.
- txn.execute_batch(
- """INSERT INTO current_state_events
+ # We include the membership in the current state table, hence we do
+ # a lookup when we insert. This assumes that all events have already
+ # been inserted into room_memberships.
+ txn.execute_batch(
+ """INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership, event_stream_ordering)
VALUES (
?, ?, ?, ?,
@@ -1130,34 +1155,34 @@ class PersistEventsStore:
(SELECT stream_ordering FROM events WHERE event_id = ?)
)
""",
- [
- (room_id, key[0], key[1], ev_id, ev_id, ev_id)
- for key, ev_id in to_insert.items()
- ],
- )
+ [
+ (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ ],
+ )
- # We now update `local_current_membership`. We do this regardless
- # of whether we're still in the room or not to handle the case where
- # e.g. we just got banned (where we need to record that fact here).
-
- # Note: Do we really want to delete rows here (that we do not
- # subsequently reinsert below)? While technically correct it means
- # we have no record of the fact the user *was* a member of the
- # room but got, say, state reset out of it.
- if to_delete or to_insert:
- txn.execute_batch(
- "DELETE FROM local_current_membership"
- " WHERE room_id = ? AND user_id = ?",
- (
- (room_id, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- if etype == EventTypes.Member and self.is_mine_id(state_key)
- ),
- )
+ # We now update `local_current_membership`. We do this regardless
+ # of whether we're still in the room or not to handle the case where
+ # e.g. we just got banned (where we need to record that fact here).
- if to_insert:
- txn.execute_batch(
- """INSERT INTO local_current_membership
+ # Note: Do we really want to delete rows here (that we do not
+ # subsequently reinsert below)? While technically correct it means
+ # we have no record of the fact the user *was* a member of the
+ # room but got, say, state reset out of it.
+ if to_delete or to_insert:
+ txn.execute_batch(
+ "DELETE FROM local_current_membership"
+ " WHERE room_id = ? AND user_id = ?",
+ (
+ (room_id, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ if etype == EventTypes.Member and self.is_mine_id(state_key)
+ ),
+ )
+
+ if to_insert:
+ txn.execute_batch(
+ """INSERT INTO local_current_membership
(room_id, user_id, event_id, membership, event_stream_ordering)
VALUES (
?, ?, ?,
@@ -1165,29 +1190,27 @@ class PersistEventsStore:
(SELECT stream_ordering FROM events WHERE event_id = ?)
)
""",
- [
- (room_id, key[1], ev_id, ev_id, ev_id)
- for key, ev_id in to_insert.items()
- if key[0] == EventTypes.Member and self.is_mine_id(key[1])
- ],
- )
-
- txn.call_after(
- self.store._curr_state_delta_stream_cache.entity_has_changed,
- room_id,
- stream_id,
+ [
+ (room_id, key[1], ev_id, ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+ ],
)
- # Invalidate the various caches
- self.store._invalidate_state_caches_and_stream(
- txn, room_id, members_changed
- )
+ txn.call_after(
+ self.store._curr_state_delta_stream_cache.entity_has_changed,
+ room_id,
+ stream_id,
+ )
- # Check if any of the remote membership changes requires us to
- # unsubscribe from their device lists.
- self.store.handle_potentially_left_users_txn(
- txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
- )
+ # Invalidate the various caches
+ self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+ # Check if any of the remote membership changes requires us to
+ # unsubscribe from their device lists.
+ self.store.handle_potentially_left_users_txn(
+ txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+ )
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
@@ -1221,23 +1244,19 @@ class PersistEventsStore:
def _update_forward_extremities_txn(
self,
txn: LoggingTransaction,
- new_forward_extremities: Dict[str, Set[str]],
+ room_id: str,
+ new_forward_extremities: Set[str],
max_stream_order: int,
) -> None:
- for room_id in new_forward_extremities.keys():
- self.db_pool.simple_delete_txn(
- txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
- )
+ self.db_pool.simple_delete_txn(
+ txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+ )
self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
keys=("event_id", "room_id"),
- values=[
- (ev_id, room_id)
- for room_id, new_extrem in new_forward_extremities.items()
- for ev_id in new_extrem
- ],
+ values=[(ev_id, room_id) for ev_id in new_forward_extremities],
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
@@ -1249,8 +1268,7 @@ class PersistEventsStore:
keys=("room_id", "event_id", "stream_ordering"),
values=[
(room_id, event_id, max_stream_order)
- for room_id, new_extrem in new_forward_extremities.items()
- for event_id in new_extrem
+ for event_id in new_forward_extremities
],
)
@@ -1287,36 +1305,45 @@ class PersistEventsStore:
def _update_room_depths_txn(
self,
txn: LoggingTransaction,
+ room_id: str,
events_and_contexts: List[Tuple[EventBase, EventContext]],
) -> None:
"""Update min_depth for each room
Args:
txn: db connection
+ room_id: The room ID
events_and_contexts: events we are persisting
"""
- depth_updates: Dict[str, int] = {}
+ stream_ordering: Optional[int] = None
+ depth_update = 0
for event, context in events_and_contexts:
- # Then update the `stream_ordering` position to mark the latest
- # event as the front of the room. This should not be done for
- # backfilled events because backfilled events have negative
- # stream_ordering and happened in the past so we know that we don't
- # need to update the stream_ordering tip/front for the room.
+ # Don't update the stream ordering for backfilled events because
+ # backfilled events have negative stream_ordering and happened in the
+ # past, so we know that we don't need to update the stream_ordering
+ # tip/front for the room.
assert event.internal_metadata.stream_ordering is not None
if event.internal_metadata.stream_ordering >= 0:
- txn.call_after(
- self.store._events_stream_cache.entity_has_changed,
- event.room_id,
- event.internal_metadata.stream_ordering,
- )
+ if stream_ordering is None:
+ stream_ordering = event.internal_metadata.stream_ordering
+ else:
+ stream_ordering = max(
+ stream_ordering, event.internal_metadata.stream_ordering
+ )
if not event.internal_metadata.is_outlier() and not context.rejected:
- depth_updates[event.room_id] = max(
- event.depth, depth_updates.get(event.room_id, event.depth)
- )
+ depth_update = max(event.depth, depth_update)
- for room_id, depth in depth_updates.items():
- self._update_min_depth_for_room_txn(txn, room_id, depth)
+ # Then update the `stream_ordering` position to mark the latest event as
+ # the front of the room.
+ if stream_ordering is not None:
+ txn.call_after(
+ self.store._events_stream_cache.entity_has_changed,
+ room_id,
+ stream_ordering,
+ )
+
+ self._update_min_depth_for_room_txn(txn, room_id, depth_update)
def _update_outliers_txn(
self,
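
A rough, self-contained sketch of the single-room accumulation above, with plain tuples standing in for EventBase/EventContext (which are not reproduced here): one pass computes the highest non-negative stream ordering and the largest depth of accepted, non-outlier events.

    from typing import Optional

    events = [
        # (stream_ordering, depth, is_outlier, rejected)
        (-3, 10, False, False),   # backfilled: negative stream ordering is ignored
        (21, 12, False, False),
        (22, 11, True, False),    # outlier: does not count towards min_depth
    ]

    stream_ordering: Optional[int] = None
    depth_update = 0
    for so, depth, is_outlier, rejected in events:
        if so >= 0:
            stream_ordering = so if stream_ordering is None else max(stream_ordering, so)
        if not is_outlier and not rejected:
            depth_update = max(depth, depth_update)

    print(stream_ordering, depth_update)  # 22 12

Once a batch is known to belong to a single room, only these two maxima are needed, which is why the per-room dictionaries could be dropped.
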
@@ -1339,13 +1366,19 @@ class PersistEventsStore:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
- txn.execute(
- "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
- % (",".join(["?"] * len(events_and_contexts)),),
- [event.event_id for event, _ in events_and_contexts],
+ rows = cast(
+ List[Tuple[str, bool]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ "events",
+ "event_id",
+ [event.event_id for event, _ in events_and_contexts],
+ keyvalues={},
+ retcols=("event_id", "outlier"),
+ ),
)
- have_persisted = dict(cast(Iterable[Tuple[str, bool]], txn))
+ have_persisted = dict(rows)
logger.debug(
"_update_outliers_txn: events=%s have_persisted=%s",
@@ -1443,7 +1476,7 @@ class PersistEventsStore:
txn,
table="event_json",
keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
- values=(
+ values=[
(
event.event_id,
event.room_id,
@@ -1452,7 +1485,7 @@ class PersistEventsStore:
event.format_version,
)
for event, _ in events_and_contexts
- ),
+ ],
)
self.db_pool.simple_insert_many_txn(
@@ -1475,7 +1508,7 @@ class PersistEventsStore:
"state_key",
"rejection_reason",
),
- values=(
+ values=[
(
self._instance_name,
event.internal_metadata.stream_ordering,
@@ -1494,7 +1527,7 @@ class PersistEventsStore:
context.rejected,
)
for event, context in events_and_contexts
- ),
+ ],
)
# If we're persisting an unredacted event we go and ensure
@@ -1517,11 +1550,11 @@ class PersistEventsStore:
txn,
table="state_events",
keys=("event_id", "room_id", "type", "state_key"),
- values=(
+ values=[
(event.event_id, event.room_id, event.type, event.state_key)
for event, _ in events_and_contexts
if event.is_state()
- ),
+ ],
)
def _store_rejected_events_txn(
@@ -1654,8 +1687,6 @@ class PersistEventsStore:
) -> None:
to_prefill = []
- rows = []
-
ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map:
return
@@ -1676,10 +1707,9 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- rows = self.db_pool.cursor_to_dict(txn)
- for row in rows:
- event = ev_map[row["event_id"]]
- if not row["rejects"] and not row["redacts"]:
+ for event_id, redacts, rejects in txn:
+ event = ev_map[event_id]
+ if not rejects and not redacts:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
async def external_prefill() -> None:
@@ -2259,35 +2289,59 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- # From the events passed in, add all of the prev events as backwards extremities.
- # Ignore any events that are already backwards extrems or outliers.
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- # 1. Don't add an event as a extremity again if we already persisted it
- # as a non-outlier.
- # 2. Don't add an outlier as an extremity if it has no prev_events
- " AND NOT EXISTS ("
- " SELECT 1 FROM events"
- " LEFT JOIN event_edges edge"
- " ON edge.event_id = events.event_id"
- " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)"
- " )"
+
+ room_id = events[0].room_id
+
+ potential_backwards_extremities = {
+ e_id
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ }
+
+ if not potential_backwards_extremities:
+ return
+
+ existing_events_outliers = self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=potential_backwards_extremities,
+ keyvalues={"outlier": False},
+ retcols=("event_id",),
)
- txn.execute_batch(
- query,
- [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id)
- for ev in events
- for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ],
+ potential_backwards_extremities.difference_update(
+ e for e, in existing_events_outliers
)
+ if potential_backwards_extremities:
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="event_backward_extremities",
+ key_names=("room_id", "event_id"),
+ key_values=[(room_id, ev) for ev in potential_backwards_extremities],
+ value_names=(),
+ value_values=(),
+ )
+
+ # Record the stream orderings where we have new gaps.
+ gap_events = [
+ (room_id, self._instance_name, ev.internal_metadata.stream_ordering)
+ for ev in events
+ if any(
+ e_id in potential_backwards_extremities
+ for e_id in ev.prev_event_ids()
+ )
+ ]
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="timeline_gaps",
+ keys=("room_id", "instance_name", "stream_ordering"),
+ values=gap_events,
+ )
+
# Delete all these events that we've already fetched and now know that their
# prev events are the new backwards extremeties.
query = (
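
A minimal, standalone sketch of the backwards-extremity calculation above; the Event tuple and the set of already-persisted non-outliers are illustrative stand-ins rather than Synapse types.

    from typing import List, NamedTuple, Set

    class Event(NamedTuple):
        event_id: str
        prev_event_ids: List[str]
        is_outlier: bool

    def potential_backward_extremities(
        events: List[Event], already_persisted_non_outliers: Set[str]
    ) -> Set[str]:
        # Every prev event referenced by a non-outlier event is a candidate...
        candidates = {
            prev_id
            for ev in events
            if not ev.is_outlier
            for prev_id in ev.prev_event_ids
        }
        # ...minus anything we have already persisted as a non-outlier.
        return candidates - already_persisted_non_outliers

    events = [
        Event("$C", ["$A", "$B"], False),
        Event("$D", ["$C"], False),
    ]
    print(potential_backward_extremities(events, {"$A", "$C"}))  # {'$B'}
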
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index daef3685b0..0061805150 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -369,18 +369,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="event_json",
- column="event_id",
- iterable=chunk,
- retcols=["event_id", "json"],
- keyvalues={},
+ ev_rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_json",
+ column="event_id",
+ iterable=chunk,
+ retcols=["event_id", "json"],
+ keyvalues={},
+ ),
)
- for row in ev_rows:
- event_id = row["event_id"]
- event_json = db_to_json(row["json"])
+ for event_id, json in ev_rows:
+ event_json = db_to_json(json)
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -563,15 +565,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="events",
- column="event_id",
- iterable=to_delete,
- keyvalues={},
- retcols=("room_id",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ retcols=("room_id",),
+ ),
)
- room_ids = {row["room_id"] for row in rows}
+ room_ids = {row[0] for row in rows}
for room_id in room_ids:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
@@ -1038,18 +1043,21 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
count = len(rows)
# We also need to fetch the auth events for them.
- auth_events = self.db_pool.simple_select_many_txn(
- txn,
- table="event_auth",
- column="event_id",
- iterable=event_to_room_id,
- keyvalues={},
- retcols=("event_id", "auth_id"),
+ auth_events = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ ),
)
event_to_auth_chain: Dict[str, List[str]] = {}
- for row in auth_events:
- event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+ for event_id, auth_id in auth_events:
+ event_to_auth_chain.setdefault(event_id, []).append(auth_id)
# Calculate and persist the chain cover index for this set of events.
#
@@ -1302,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
- # We need to pass execute a dummy function to handle the txn's result otherwise
- # it tries to call fetchall() on it and fails because there's no result to fetch.
- await self.db_pool.execute(
+ await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
- lambda txn: None,
- "ANALYZE events(stream_ordering2)",
+ lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(
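
A loose sketch of the pattern adopted above for statements that return no rows: run them inside a transaction callback and ignore the result, rather than routing them through a helper that expects fetchable rows. Here sqlite3 stands in for the real database pool and run_interaction is a made-up analogue of runInteraction.

    import sqlite3
    from typing import Callable

    def run_interaction(
        conn: sqlite3.Connection, func: Callable[[sqlite3.Cursor], object]
    ) -> None:
        # Very loosely analogous to db_pool.runInteraction: hand the callback a
        # cursor and commit when it returns.
        cur = conn.cursor()
        func(cur)
        conn.commit()

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (stream_ordering2 INTEGER)")
    run_interaction(conn, lambda txn: txn.execute("ANALYZE"))
    print("ANALYZE ran without needing to fetch any rows")
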
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index f851bff604..0ba84b1469 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
- ) -> List[Dict[str, Any]]:
- """Get list of forward extremities for a room."""
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
+ """
+ Get list of forward extremities for a room.
+
+ Returns:
+ A list of tuples of event_id, state_group, depth, and received_ts.
+ """
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b788d70fc5..5bf864c1fb 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1584,16 +1584,19 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ ),
)
- return {r["event_id"] for r in rows}
+ return {r[0] for r in rows}
@trace
@tag_args
@@ -2093,12 +2096,6 @@ class EventsWorkerStore(SQLBaseStore):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
- DELETE FROM event_txn_id
- WHERE inserted_ts < ?
- """
- txn.execute(sql, (one_day_ago,))
-
- sql = """
DELETE FROM event_txn_id_device_id
WHERE inserted_ts < ?
"""
@@ -2336,15 +2333,18 @@ class EventsWorkerStore(SQLBaseStore):
a dict mapping from event id to partial-stateness. We return True for
any of the events which are unknown (or are outliers).
"""
- result = await self.db_pool.simple_select_many_batch(
- table="partial_state_events",
- column="event_id",
- iterable=event_ids,
- retcols=["event_id"],
- desc="get_partial_state_events",
+ result = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="partial_state_events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=["event_id"],
+ desc="get_partial_state_events",
+ ),
)
# convert the result to a dict, to make @cachedList work
- partial = {r["event_id"] for r in result}
+ partial = {r[0] for r in result}
return {e_id: e_id in partial for e_id in event_ids}
@cached()
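
A toy illustration of the membership-check pattern used in get_partial_state_events above: rows of single-element tuples are collapsed into a set, then every requested ID maps to whether it appeared. The IDs are made up.

    from typing import List, Tuple

    event_ids = ["$a", "$b", "$c"]
    result: List[Tuple[str]] = [("$a",), ("$c",)]  # rows from partial_state_events

    partial = {r[0] for r in result}
    print({e_id: e_id in partial for e_id in event_ids})
    # {'$a': True, '$b': False, '$c': True}
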
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index 654f924019..60621edeef 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, FrozenSet
+from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
- enabled = await self.db_pool.simple_select_list(
- "per_user_experimental_features",
- {"user_id": user_id, "enabled": True},
- ["feature"],
+ enabled = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="per_user_experimental_features",
+ keyvalues={"user_id": user_id, "enabled": True},
+ retcols=("feature",),
+ ),
)
- return frozenset(feature["feature"] for feature in enabled)
+ return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 889c578b9c..ce88772f9e 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
import itertools
import json
import logging
-from typing import Dict, Iterable, Mapping, Optional, Tuple
+from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="server_keys_json",
- column="key_id",
- iterable=key_ids,
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
- # We sort the rows so that the most recently added entry is picked up.
- rows.sort(key=lambda r: r["ts_added_ms"])
+ # We sort the rows by ts_added_ms so that the most recently added entry
+ # will stomp over older entries in the dictionary.
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
async def get_all_server_keys_json_for_remote(
@@ -244,30 +248,35 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_list(
- table="server_keys_json",
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
- rows.sort(key=lambda r: r["ts_added_ms"])
+ # We sort the rows by ts_added_ms so that the most recently added entry
+ # will stomp over older entries in the dictionary.
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
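
A standalone illustration of the sort-then-stomp trick documented above: sorting by ts_added_ms (index 2) before building the dict means the most recently added row for a given key_id ends up in the result. The rows are invented.

    rows = [
        ("ed25519:a", "server1", 300, 900, b"newer"),
        ("ed25519:a", "server1", 100, 500, b"older"),
        ("ed25519:b", "server2", 200, 700, b"only"),
    ]

    rows.sort(key=lambda r: r[2])  # ascending ts_added_ms
    latest = {key_id: key_json for key_id, _from, _ts, _valid, key_json in rows}
    print(latest)  # ed25519:a maps to b'newer'
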
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8cebeb5189..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,8 +26,11 @@ from typing import (
cast,
)
+import attr
+
from synapse.api.constants import Direction
from synapse.logging.opentracing import trace
+from synapse.media._base import ThumbnailInfo
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -44,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+ safe_from_quarantine: bool
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -179,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -196,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
# Set ordering
order_by_column = MediaSortOrder(order_by).value
@@ -216,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = """
SELECT
- "media_id",
- "media_type",
- "media_length",
- "upload_name",
- "created_ts",
- "last_access_ts",
- "quarantined_by",
- "safe_from_quarantine"
+ media_id,
+ media_type,
+ media_length,
+ upload_name,
+ created_ts,
+ last_access_ts,
+ quarantined_by,
+ safe_from_quarantine
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@@ -235,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
args += [limit, start]
txn.execute(sql, args)
- media = self.db_pool.cursor_to_dict(txn)
+ media = [
+ LocalMedia(
+ media_id=row[0],
+ media_type=row[1],
+ media_length=row[2],
+ upload_name=row[3],
+ created_ts=row[4],
+ last_access_ts=row[5],
+ quarantined_by=row[6],
+ safe_from_quarantine=bool(row[7]),
+ )
+ for row in txn
+ ]
return media, count
return await self.db_pool.runInteraction(
@@ -435,19 +462,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "local_media_repository_thumbnails",
- {"media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "local_media_repository_thumbnails",
+ {"media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_local_media_thumbnails",
),
- desc="get_local_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
+ for row in rows
+ ]
@trace
async def store_local_thumbnail(
@@ -556,20 +592,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
- ) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "remote_media_cache_thumbnails",
- {"media_origin": origin, "media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
- "filesystem_id",
+ ) -> List[ThumbnailInfo]:
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "remote_media_cache_thumbnails",
+ {"media_origin": origin, "media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_remote_media_thumbnails",
),
- desc="get_remote_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
+ for row in rows
+ ]
@trace
async def get_remote_media_thumbnail(
@@ -632,7 +676,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -646,12 +690,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
+ * The filesystem ID.
+ """
+
+ sql = """
+ SELECT media_origin, media_id, filesystem_id
+ FROM remote_media_cache
+ WHERE last_access_ts < ?
"""
- sql = (
- "SELECT media_origin, media_id, filesystem_id"
- " FROM remote_media_cache"
- " WHERE last_access_ts < ?"
- )
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -659,8 +705,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
- return await self.db_pool.execute(
- "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+ return cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
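
A sketch of the row-to-object conversion used for LocalMedia above, written with a stdlib dataclass instead of attrs so it runs without extra dependencies; note the explicit bool() coercion, since SQLite returns booleans as integers.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class LocalMediaSketch:
        media_id: str
        media_type: str
        media_length: int
        safe_from_quarantine: bool

    raw_rows = [
        ("abc123", "image/png", 1024, 1),   # SQLite hands booleans back as ints...
        ("def456", "image/jpeg", 2048, 0),
    ]

    media = [
        LocalMediaSketch(
            media_id=row[0],
            media_type=row[1],
            media_length=row[2],
            safe_from_quarantine=bool(row[3]),  # ...so coerce explicitly.
        )
        for row in raw_rows
    ]
    print(media)
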
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 194b4e031f..3b444d2d07 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -20,6 +20,7 @@ from typing import (
Mapping,
Optional,
Tuple,
+ Union,
cast,
)
@@ -260,27 +261,40 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
async def get_presence_for_users(
self, user_ids: Iterable[str]
) -> Mapping[str, UserPresenceState]:
- rows = await self.db_pool.simple_select_many_batch(
- table="presence_stream",
- column="user_id",
- iterable=user_ids,
- keyvalues={},
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
+ # TODO All these columns are nullable, but we don't expect that:
+ # https://github.com/matrix-org/synapse/issues/16467
+ rows = cast(
+ List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+ await self.db_pool.simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ desc="get_presence_for_users",
),
- desc="get_presence_for_users",
)
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return {row["user_id"]: UserPresenceState(**row) for row in rows}
+ return {
+ user_id: UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ }
async def should_user_receive_full_presence_with_token(
self,
@@ -385,28 +399,49 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
limit = 100
offset = 0
while True:
- rows = await self.db_pool.runInteraction(
- "get_presence_for_all_users",
- self.db_pool.simple_select_list_paginate_txn,
- "presence_stream",
- orderby="stream_id",
- start=offset,
- limit=limit,
- exclude_keyvalues=exclude_keyvalues,
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
+ # TODO All these columns are nullable, but we don't expect that:
+ # https://github.com/matrix-org/synapse/issues/16467
+ rows = cast(
+ List[Tuple[str, str, int, int, int, Optional[str], Union[int, bool]]],
+ await self.db_pool.runInteraction(
+ "get_presence_for_all_users",
+ self.db_pool.simple_select_list_paginate_txn,
+ "presence_stream",
+ orderby="stream_id",
+ start=offset,
+ limit=limit,
+ exclude_keyvalues=exclude_keyvalues,
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ order_direction="ASC",
),
- order_direction="ASC",
)
- for row in rows:
- users_to_state[row["user_id"]] = UserPresenceState(**row)
+ for (
+ user_id,
+ state,
+ last_active_ts,
+ last_federation_update_ts,
+ last_user_sync_ts,
+ status_msg,
+ currently_active,
+ ) in rows:
+ users_to_state[user_id] = UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
# We've run out of updates to query
if len(rows) < limit:
@@ -434,13 +469,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
txn.close()
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return [UserPresenceState(**row) for row in rows]
+ return [
+ UserPresenceState(
+ user_id=user_id,
+ state=state,
+ last_active_ts=last_active_ts,
+ last_federation_update_ts=last_federation_update_ts,
+ last_user_sync_ts=last_user_sync_ts,
+ status_msg=status_msg,
+ currently_active=bool(currently_active),
+ )
+ for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+ ]
def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup
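
A tiny sketch of the paginated-fetch loop in get_presence_for_all_users above: keep pulling fixed-size pages until a short page signals the end. fetch_page is a made-up stand-in for the real paginated select.

    from typing import List, Tuple

    DATA = [(f"@user{i}:example.org", "online") for i in range(7)]

    def fetch_page(start: int, limit: int) -> List[Tuple[str, str]]:
        return DATA[start : start + limit]

    limit, offset = 3, 0
    users_to_state = {}
    while True:
        rows = fetch_page(offset, limit)
        for user_id, state in rows:
            users_to_state[user_id] = state
        if len(rows) < limit:
            break  # short page: we've run out of rows
        offset += limit

    print(len(users_to_state))  # 7
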
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index dea0e0458c..1e11bf2706 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -89,6 +89,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# furthermore, we might already have the table from a previous (failed)
# purge attempt, so let's drop the table first.
+ if isinstance(self.database_engine, PostgresEngine):
+ # Disable statement timeouts for this transaction; purging rooms can
+ # take a while!
+ txn.execute("SET LOCAL statement_timeout = 0")
+
txn.execute("DROP TABLE IF EXISTS events_to_purge")
txn.execute(
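
A rough sketch of the engine-conditional timeout tweak above. SET LOCAL only affects the current transaction, so the relaxed statement_timeout cannot leak onto other queries sharing the connection; FakeTxn and the is_postgres flag are stand-ins for Synapse's real types.

    class FakeTxn:
        def execute(self, sql: str) -> None:
            print("would execute:", sql)

    def prepare_purge(txn: FakeTxn, is_postgres: bool) -> None:
        if is_postgres:
            # Purging a room can take a long time; disable the per-statement
            # timeout for this transaction only.
            txn.execute("SET LOCAL statement_timeout = 0")
        txn.execute("DROP TABLE IF EXISTS events_to_purge")

    prepare_purge(FakeTxn(), is_postgres=True)
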
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 923166974c..37135d431d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,8 +28,11 @@ from typing import (
cast,
)
+from twisted.internet import defer
+
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import (
)
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_encoder, unwrapFirstError
+from synapse.util.async_helpers import gather_results
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -62,20 +66,34 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict],
+ rawrules: List[Tuple[str, int, str, str]],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
+
+ Args:
+ rawrules: List of tuples of:
+ * rule ID
+ * Priority class
+ * Conditions (as serialized JSON)
+ * Actions (as serialized JSON)
+ enabled_map: A dictionary mapping rule IDs to a boolean indicating whether the rule is
+ enabled. This might not include all rule IDs from rawrules.
+ experimental_config: The `experimental_features` section of the Synapse
+ config. (Used to check if various features are enabled.)
+
+ Returns:
+ A new FilteredPushRules object.
"""
ruleslist = [
PushRule.from_db(
- rule_id=rawrule["rule_id"],
- priority_class=rawrule["priority_class"],
- conditions=rawrule["conditions"],
- actions=rawrule["actions"],
+ rule_id=rawrule[0],
+ priority_class=rawrule[1],
+ conditions=rawrule[2],
+ actions=rawrule[3],
)
for rawrule in rawrules
]
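
A toy illustration of the row shape documented above: each raw rule is a (rule_id, priority_class, conditions, actions) tuple whose last two fields are serialized JSON. The decoding here is only for display; the real code passes the strings straight to PushRule.from_db.

    import json
    from typing import List, Tuple

    rawrules: List[Tuple[str, int, str, str]] = [
        (
            "global/override/.m.rule.suppress_notices",
            5,
            json.dumps(
                [{"kind": "event_match", "key": "content.msgtype", "pattern": "m.notice"}]
            ),
            json.dumps(["dont_notify"]),
        ),
    ]

    for rule_id, priority_class, conditions, actions in rawrules:
        print(rule_id, priority_class, json.loads(conditions), json.loads(actions))
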
@@ -165,34 +183,44 @@ class PushRulesWorkerStore(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
- rows = await self.db_pool.simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
+ rows = cast(
+ List[Tuple[str, int, int, str, str]],
+ await self.db_pool.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_for_user",
),
- desc="get_push_rules_for_user",
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map, self.hs.config.experimental)
+ return _load_rules(
+ [(row[0], row[1], row[3], row[4]) for row in rows],
+ enabled_map,
+ self.hs.config.experimental,
+ )
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
- results = await self.db_pool.simple_select_list(
- table="push_rules_enable",
- keyvalues={"user_name": user_id},
- retcols=("rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
+ results = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ ),
)
- return {r["rule_id"]: bool(r["enabled"]) for r in results}
+ return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
@@ -221,23 +249,46 @@ class PushRulesWorkerStore(
if not user_ids:
return {}
- raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[Tuple[str, int, str, str]]] = {
+ user_id: [] for user_id in user_ids
+ }
- rows = await self.db_pool.simple_select_many_batch(
- table="push_rules",
- column="user_name",
- iterable=user_ids,
- retcols=("*",),
- desc="bulk_get_push_rules",
- batch_size=1000,
+ # gatherResults loses all type information.
+ rows, enabled_map_by_user = await make_deferred_yieldable(
+ gather_results(
+ (
+ cast(
+ "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]",
+ run_in_background(
+ self.db_pool.simple_select_many_batch,
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=(
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="bulk_get_push_rules",
+ batch_size=1000,
+ ),
+ ),
+ run_in_background(self.bulk_get_push_rules_enabled, user_ids),
+ ),
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
- for row in rows:
- raw_rules.setdefault(row["user_name"], []).append(row)
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[2]), -int(row[3])))
- enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
+ for user_name, rule_id, priority_class, _, conditions, actions in rows:
+ raw_rules.setdefault(user_name, []).append(
+ (rule_id, priority_class, conditions, actions)
+ )
results: Dict[str, FilteredPushRules] = {}
@@ -256,17 +307,19 @@ class PushRulesWorkerStore(
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
- rows = await self.db_pool.simple_select_many_batch(
- table="push_rules_enable",
- column="user_name",
- iterable=user_ids,
- retcols=("user_name", "rule_id", "enabled"),
- desc="bulk_get_push_rules_enabled",
- batch_size=1000,
+ rows = cast(
+ List[Tuple[str, str, Optional[int]]],
+ await self.db_pool.simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
+ ),
)
- for row in rows:
- enabled = bool(row["enabled"])
- results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+ for user_name, rule_id, enabled in rows:
+ results.setdefault(user_name, {})[rule_id] = bool(enabled)
return results
async def get_all_push_rule_updates(
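
A sketch of the concurrent-fetch pattern introduced in bulk_get_push_rules above, with asyncio.gather standing in for Twisted's gather_results/run_in_background and two fake fetchers in place of the real queries.

    import asyncio
    from typing import Dict, List, Tuple

    async def fetch_rules(user_ids: List[str]) -> List[Tuple[str, str]]:
        return [(u, "some_rule") for u in user_ids]

    async def fetch_enabled(user_ids: List[str]) -> Dict[str, Dict[str, bool]]:
        return {u: {"some_rule": True} for u in user_ids}

    async def main() -> None:
        users = ["@alice:example.org", "@bob:example.org"]
        # Kick off both queries at once and wait for both results.
        rows, enabled_map_by_user = await asyncio.gather(
            fetch_rules(users), fetch_enabled(users)
        )
        print(rows, enabled_map_by_user)

    asyncio.run(main())
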
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 87e28e22d3..a6a1671bd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+ int, # id
+ str, # user_name
+ Optional[int], # access_token
+ str, # profile_tag
+ str, # kind
+ str, # app_id
+ str, # app_display_name
+ str, # device_display_name
+ str, # pushkey
+ int, # ts
+ str, # lang
+ str, # data
+ int, # last_stream_ordering
+ int, # last_success
+ int, # failing_since
+ bool, # enabled
+ str, # device_id
+]
+
class PusherWorkerStore(SQLBaseStore):
def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers,
)
- def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+ def _decode_pushers_rows(
+ self,
+ rows: Iterable[PusherRow],
+ ) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded
"""
- for r in rows:
- data_json = r["data"]
+ for (
+ id,
+ user_name,
+ access_token,
+ profile_tag,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ ts,
+ lang,
+ data,
+ last_stream_ordering,
+ last_success,
+ failing_since,
+ enabled,
+ device_id,
+ ) in rows:
try:
- r["data"] = db_to_json(data_json)
+ data_json = db_to_json(data)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r["id"],
- data_json,
+ id,
+ data,
e.args[0],
)
continue
- # If we're using SQLite, then boolean values are integers. This is
- # troublesome since some code using the return value of this method might
- # expect it to be a boolean, or will expose it to clients (in responses).
- r["enabled"] = bool(r["enabled"])
-
- yield PusherConfig(**r)
+ yield PusherConfig(
+ id=id,
+ user_name=user_name,
+ profile_tag=profile_tag,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ ts=ts,
+ lang=lang,
+ data=data_json,
+ last_stream_ordering=last_stream_ordering,
+ last_success=last_success,
+ failing_since=failing_since,
+ # If we're using SQLite, then boolean values are integers. This is
+ # troublesome since some code using the return value of this method might
+ # expect it to be a boolean, or will expose it to clients (in responses).
+ enabled=bool(enabled),
+ device_id=device_id,
+ access_token=access_token,
+ )
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values.
"""
- def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values()))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction(
desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
- def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
- rows = self.db_pool.cursor_to_dict(txn)
-
- return self._decode_pushers_rows(rows)
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+ txn.execute(
+ """
+ SELECT id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ enabled, device_id
+ FROM pushers WHERE COALESCE(enabled, TRUE)
+ """
+ )
+ return cast(List[PusherRow], txn.fetchall())
- return await self.db_pool.runInteraction(
- "get_enabled_pushers", get_enabled_pushers_txn
+ return self._decode_pushers_rows(
+ await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
)
async def get_all_updated_pushers_rows(
@@ -304,26 +369,28 @@ class PusherWorkerStore(SQLBaseStore):
)
async def get_throttle_params_by_room(
- self, pusher_id: str
+ self, pusher_id: int
) -> Dict[str, ThrottleParams]:
- res = await self.db_pool.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
+ res = cast(
+ List[Tuple[str, Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ ),
)
params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"],
- row["throttle_ms"],
+ for room_id, last_sent_ts, throttle_ms in res:
+ params_by_room[room_id] = ThrottleParams(
+ last_sent_ts or 0, throttle_ms or 0
)
return params_by_room
async def set_throttle_params(
- self, pusher_id: str, room_id: str, params: ThrottleParams
+ self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None:
await self.db_pool.simple_upsert(
"pusher_throttle",
@@ -534,7 +601,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if len(rows) == 0:
return 0
@@ -550,19 +617,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="pushers",
key_names=("id",),
- key_values=[(row["pusher_id"],) for row in rows],
+ key_values=[(row[0],) for row in rows],
value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table.
value_values=[
- (row["pusher_device_id"] or row["token_device_id"], None)
- for row in rows
+ (pusher_device_id or token_device_id, None)
+ for _, pusher_device_id, token_device_id in rows
],
)
self.db_pool.updates._background_update_progress_txn(
- txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+ txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
)
return len(rows)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0231f9407b..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from immutabledict import immutabledict
+
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ MultiWriterStreamToken,
+ PersistedPosition,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
- max_value=max_receipts_stream_id,
+ max_value=max_receipts_stream_id.stream,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
- def get_max_receipt_stream_id(self) -> int:
+ def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
- return self._receipts_id_gen.get_current_token()
+
+ min_pos = self._receipts_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._receipts_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return MultiWriterStreamToken(
+ stream=min_pos, instance_map=immutabledict(positions)
+ )
+
+ def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+ return self._receipts_id_gen.get_current_token_for_writer(instance_name)
def get_last_unthreaded_receipt_for_user_txn(
self,
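
The token returned above pairs the minimum persisted stream position with the per-writer positions that have advanced past it; the receipt queries further down then check each row's `(instance_name, stream_id)` against that token. A rough sketch of the idea — the class and function below are stand-ins, not the real `MultiWriterStreamToken` API:

```python
# Stand-in classes, not the real MultiWriterStreamToken API: a token is a
# floor position that every writer has reached, plus per-writer positions
# that are ahead of that floor.
from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass(frozen=True)
class TokenSketch:
    stream: int                                   # minimum persisted position
    instance_map: Dict[str, int] = field(default_factory=dict)

    def get_max_stream_pos(self) -> int:
        return max(self.instance_map.values(), default=self.stream)

def in_range(from_token: Optional[TokenSketch], to_token: TokenSketch,
             instance: str, stream_id: int) -> bool:
    # A row is visible if it sits above the lower bound (when there is one)
    # and at or below the upper bound for the instance that wrote it.
    low = from_token.instance_map.get(instance, from_token.stream) if from_token else 0
    high = to_token.instance_map.get(instance, to_token.stream)
    return low < stream_id <= high

print(in_range(None, TokenSketch(5, {"worker1": 7}), "worker1", 6))  # True
print(in_range(None, TokenSketch(5, {"worker1": 7}), "worker2", 6))  # False
```
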
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Iterable[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
- room_ids, from_key
+ room_ids, from_key.stream
)
results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
- if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ if not self._receipts_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ ):
return []
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(tree=True)
async def _get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id > ? AND stream_id <= ?"
- )
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, from_key, to_key))
- else:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id <= ?"
+ txn.execute(
+ sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
)
+ else:
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
+ room_id = ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, to_key))
-
- rows = self.db_pool.cursor_to_dict(txn)
+ txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
- return rows
+ return [
+ (receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -339,10 +387,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return []
content: JsonDict = {}
- for row in rows:
- content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
- row["user_id"]
- ] = db_to_json(row["data"])
+ for receipt_type, user_id, event_id, data in rows:
+ content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+ user_id
+ ] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@@ -352,25 +400,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=3,
)
async def _get_linearized_receipts_for_rooms(
- self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Collection[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [from_key, to_key] + list(args))
+ txn.execute(
+ sql + clause,
+ [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+ )
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
+ FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -378,31 +438,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [to_key] + list(args))
+ txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return self.db_pool.cursor_to_dict(txn)
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
- if row["thread_id"]:
- receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+ receipt_type_dict[user_id] = db_to_json(data)
+ if thread_id:
+ receipt_type_dict[user_id]["thread_id"] = thread_id
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -414,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
- self, to_key: int, from_key: Optional[int] = None
+ self,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -428,46 +496,54 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts.
"""
- def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [from_key, to_key])
+ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
else:
sql = """
- SELECT * FROM receipts_linearized WHERE
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [to_key])
+ txn.execute(sql, [to_key.get_max_stream_pos()])
- return self.db_pool.cursor_to_dict(txn)
+ return [
+ (room_id, receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results: JsonDict = {}
- for row in txn_results:
+ for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
- row["room_id"],
- {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+ room_id,
+ {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(row["event_id"], {})
- receipt_type = event_entry.setdefault(row["receipt_type"], {})
+ event_entry = room_event["content"].setdefault(event_id, {})
+ receipt_type_dict = event_entry.setdefault(receipt_type, {})
- receipt_type[row["user_id"]] = db_to_json(row["data"])
+ receipt_type_dict[user_id] = db_to_json(data)
return results
@@ -537,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
updates = cast(
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -687,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues=keyvalues,
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
@@ -742,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[Tuple[int, int]]:
+ ) -> Optional[PersistedPosition]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -804,9 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
+ return PersistedPosition(self._instance_name, stream_id)
async def _insert_graph_receipt(
self,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..933d76e905 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -143,6 +143,30 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider."""
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidResult:
+ medium: str
+ address: str
+ validated_at: int
+ added_at: int
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+ address: str
+ """address of the 3pid"""
+ medium: str
+ """medium of the 3pid"""
+ client_secret: str
+ """a secret provided by the client for this validation session"""
+ session_id: str
+ """ID of the validation session"""
+ last_send_attempt: int
+ """a number serving to dedupe send attempts for this session"""
+ validated_at: Optional[int]
+ """timestamp of when this session was validated if so"""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -195,7 +219,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""
- def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
@@ -213,35 +237,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
-
- if len(rows) == 0:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ (
+ name,
+ is_guest,
+ admin,
+ consent_version,
+ consent_ts,
+ consent_server_notice_sent,
+ appservice_id,
+ creation_ts,
+ user_type,
+ deactivated,
+ shadow_banned,
+ approved,
+ locked,
+ ) = row
+
+ return UserInfo(
+ appservice_id=appservice_id,
+ consent_server_notice_sent=consent_server_notice_sent,
+ consent_version=consent_version,
+ consent_ts=consent_ts,
+ creation_ts=creation_ts,
+ is_admin=bool(admin),
+ is_deactivated=bool(deactivated),
+ is_guest=bool(is_guest),
+ is_shadow_banned=bool(shadow_banned),
+ user_id=UserID.from_string(name),
+ user_type=user_type,
+ approved=bool(approved),
+ locked=bool(locked),
+ )
- row = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
- if row is None:
- return None
-
- return UserInfo(
- appservice_id=row["appservice_id"],
- consent_server_notice_sent=row["consent_server_notice_sent"],
- consent_version=row["consent_version"],
- consent_ts=row["consent_ts"],
- creation_ts=row["creation_ts"],
- is_admin=bool(row["admin"]),
- is_deactivated=bool(row["deactivated"]),
- is_guest=bool(row["is_guest"]),
- is_shadow_banned=bool(row["shadow_banned"]),
- user_id=UserID.from_string(row["name"]),
- user_type=row["user_type"],
- approved=bool(row["approved"]),
- locked=bool(row["locked"]),
- )
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +614,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (token,))
- rows = self.db_pool.cursor_to_dict(txn)
-
- if rows:
- row = rows[0]
-
- # This field is nullable, ensure it comes out as a boolean
- if row["token_used"] is None:
- row["token_used"] = False
+ row = txn.fetchone()
- return TokenLookupResult(**row)
+ if row:
+ (
+ user_id,
+ is_guest,
+ shadow_banned,
+ token_id,
+ device_id,
+ valid_until_ms,
+ token_owner,
+ token_used,
+ ) = row
+
+ return TokenLookupResult(
+ user_id=user_id,
+ is_guest=is_guest,
+ shadow_banned=shadow_banned,
+ token_id=token_id,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ token_owner=token_owner,
+ # This field is nullable, ensure it comes out as a boolean
+ token_used=bool(token_used),
+ )
return None
@@ -821,23 +871,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
Tuples of (auth_provider, external_id)
"""
- res = await self.db_pool.simple_select_list(
- table="user_external_ids",
- keyvalues={"user_id": mxid},
- retcols=("auth_provider", "external_id"),
- desc="get_external_ids_by_user",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ ),
)
- return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -891,11 +942,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
- txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db_pool.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
+ txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+ row = txn.fetchone()
+ assert row is not None
+ return row[0]
return await self.db_pool.runInteraction("count_real_users", _count_users)
@@ -964,13 +1014,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "user_threepids",
- {"user_id": user_id},
- ["medium", "address", "validated_at", "added_at"],
- "user_get_threepids",
+ async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
+ results = cast(
+ List[Tuple[str, str, int, int]],
+ await self.db_pool.simple_select_list(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
+ ),
)
+ return [
+ ThreepidResult(
+ medium=r[0],
+ address=r[1],
+ validated_at=r[2],
+ added_at=r[3],
+ )
+ for r in results
+ ]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
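
The `cast(...)` wrapper above is the same recipe applied to every `simple_select_list` / `simple_select_many_batch` call in this diff: the helper is annotated to return untyped tuple rows, so the caller narrows the type for mypy before unpacking. A standalone sketch, where the fake helper stands in for the real one:

```python
# The fake helper stands in for simple_select_list: it is annotated to return
# generic rows, so the caller narrows the type with cast() before unpacking.
from typing import Any, List, Tuple, cast

def fake_select_list() -> List[Tuple[Any, ...]]:
    return [("email", "bob@example.com", 1690000000000, 1690000001000)]

rows = cast(List[Tuple[str, str, int, int]], fake_select_list())
for medium, address, validated_at, added_at in rows:
    print(medium, address, validated_at, added_at)
```
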
@@ -1009,7 +1071,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid",
)
- async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
+ async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
@@ -1018,15 +1080,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for
Returns:
- List of dictionaries containing the following keys:
- medium (str): The medium of the threepid (e.g "email")
- address (str): The address of the threepid (e.g "bob@example.com")
- """
- return await self.db_pool.simple_select_list(
- table="user_threepid_id_server",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address"],
- desc="user_get_bound_threepids",
+ List of tuples of two strings:
+ medium: The medium of the threepid (e.g "email")
+ address: The address of the threepid (e.g "bob@example.com")
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ ),
)
async def remove_user_bound_threepid(
@@ -1123,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
@@ -1138,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering
Returns:
- A dict containing the following:
- * address - address of the 3pid
- * medium - medium of the 3pid
- * client_secret - a secret provided by the client for this validation session
- * session_id - ID of the validation session
- * send_attempt - a number serving to dedupe send attempts for this session
- * validated_at - timestamp of when this session was validated if so
-
- Otherwise None if a validation session is not found
+ A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
@@ -1165,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1180,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ return ThreepidValidationSession(
+ address=row[0],
+ session_id=row[1],
+ medium=row[2],
+ client_secret=row[3],
+ last_send_attempt=row[4],
+ validated_at=row[5],
+ )
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
@@ -1252,12 +1316,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
txn.execute(sql, [])
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
+ for (name,) in txn.fetchall():
+ self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
@@ -1457,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_registration_tokens(
self, valid: Optional[bool] = None
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
"""List all registration tokens. Used by the admin API.
Args:
@@ -1466,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Default is None: return all tokens regardless of validity.
Returns:
- A list of dicts, each containing details of a token.
+ A list of tuples containing:
+ * The token
+ * The number of users allowed (or None)
+ * Whether it is pending
+ * Whether it has been completed
+ * An expiry time (or None if no expiry)
"""
def select_registration_tokens_txn(
txn: LoggingTransaction, now: int, valid: Optional[bool]
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
if valid is None:
# Return all tokens regardless of validity
- txn.execute("SELECT * FROM registration_tokens")
+ txn.execute(
+ """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ """
+ )
elif valid:
# Select valid tokens only
- sql = (
- "SELECT * FROM registration_tokens WHERE "
- "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
- "AND (expiry_time > ? OR expiry_time IS NULL)"
- )
+ sql = """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+ AND (expiry_time > ? OR expiry_time IS NULL)
+ """
txn.execute(sql, [now])
else:
# Select invalid tokens only
- sql = (
- "SELECT * FROM registration_tokens WHERE "
- "uses_allowed <= pending + completed OR expiry_time <= ?"
- )
+ sql = """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+ """
txn.execute(sql, [now])
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"select_registration_tokens",
@@ -1963,11 +2037,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ row = txn.fetchone()
+ assert row is not None
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
- return bool(rows[0]["approved"])
+ return bool(row[0])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
@@ -2045,22 +2120,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True, 0
rows_processed_nb = 0
- for user in rows:
- if not user["count_tokens"] and not user["count_threepids"]:
- self.set_user_deactivated_status_txn(txn, user["name"], True)
+ for name, count_tokens, count_threepids in rows:
+ if not count_tokens and not count_threepids:
+ self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn(
- txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
)
if batch_size > len(rows):
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c04d45bdb5..d0bc78b2e3 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -48,6 +48,8 @@ from synapse.storage.databases.main.stream import (
)
from synapse.storage.engines import PostgresEngine, Psycopg2Engine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -314,7 +316,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_key=next_key,
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
@@ -349,16 +351,19 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_many_txn(
- txn=txn,
- table="event_relations",
- column="relation_type",
- iterable=relation_types,
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn=txn,
+ table="event_relations",
+ column="relation_type",
+ iterable=relation_types,
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",
@@ -381,14 +386,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_list_txn(
- txn=txn,
- table="event_relations",
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_list_txn(
+ txn=txn,
+ table="event_relations",
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event",
@@ -458,7 +466,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
return result is not None
- @cached()
+ @cached() # type: ignore[synapse-@cached-mutable]
async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
raise NotImplementedError()
@@ -512,11 +520,12 @@ class RelationsWorkerStore(SQLBaseStore):
"_get_references_for_events_txn", _get_references_for_events_txn
)
- @cached()
+ @cached() # type: ignore[synapse-@cached-mutable]
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()
- @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
+ # TODO: This returns a mutable object, which is generally bad.
+ @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # type: ignore[synapse-@cached-mutable]
async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[EventBase]]:
@@ -598,11 +607,12 @@ class RelationsWorkerStore(SQLBaseStore):
for original_event_id in event_ids
}
- @cached()
+ @cached() # type: ignore[synapse-@cached-mutable]
def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
raise NotImplementedError()
- @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
+ # TODO: This returns a mutable object, which is generally bad.
+ @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # type: ignore[synapse-@cached-mutable]
async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[Tuple[int, EventBase]]]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 719e11aea6..afb880532e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
burst_count: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+ room_id: str
+ name: Optional[str]
+ canonical_alias: Optional[str]
+ joined_members: int
+ join_rules: Optional[str]
+ guest_access: Optional[str]
+ history_visibility: Optional[str]
+ state_events: int
+ avatar: Optional[str]
+ topic: Optional[str]
+ room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+ joined_local_members: int
+ version: Optional[str]
+ creator: Optional[str]
+ encryption: Optional[str]
+ federatable: bool
+ public: bool
+
+
class RoomSortOrder(Enum):
"""
Enum to define the sorting method used when returning rooms with get_rooms_paginate
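
`RoomStats` extends `LargestRoomStats`, so with `attrs` the subclass's fields are appended after the base class's; the constructor calls below use keyword arguments, which keeps them readable regardless of that ordering. A small demonstration with invented names:

```python
# Invented names: attrs appends the subclass's fields after the base class's.
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class BaseStatsSketch:
    room_id: str
    joined_members: int

@attr.s(slots=True, frozen=True, auto_attribs=True)
class FullStatsSketch(BaseStatsSketch):
    public: bool

print([f.name for f in attr.fields(FullStatsSketch)])
# ['room_id', 'joined_members', 'public']
print(FullStatsSketch(room_id="!r:example.org", joined_members=3, public=True))
```
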
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
- async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+ async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
Args:
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_room_with_stats_txn(
txn: LoggingTransaction, room_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[RoomStats]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE room_id = ?
"""
txn.execute(sql, [room_id])
- # Catch error if sql returns empty result to return "None" instead of an error
- try:
- res = self.db_pool.cursor_to_dict(txn)[0]
- except IndexError:
+ row = txn.fetchone()
+ if not row:
return None
-
- res["federatable"] = bool(res["federatable"])
- res["public"] = bool(res["public"])
- return res
+ return RoomStats(
+ room_id=row[0],
+ name=row[1],
+ canonical_alias=row[2],
+ joined_members=row[3],
+ joined_local_members=row[4],
+ version=row[5],
+ creator=row[6],
+ encryption=row[7],
+ federatable=bool(row[8]),
+ public=bool(row[9]),
+ join_rules=row[10],
+ guest_access=row[11],
+ history_visibility=row[12],
+ state_events=row[13],
+ avatar=row[14],
+ topic=row[15],
+ room_type=row[16],
+ )
return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ) -> List[Dict[str, Any]]:
+ ) -> List[LargestRoomStats]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _get_largest_public_rooms_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[LargestRoomStats]:
txn.execute(sql, query_args)
- results = self.db_pool.cursor_to_dict(txn)
+ results = [
+ LargestRoomStats(
+ room_id=r[0],
+ name=r[1],
+ canonical_alias=r[3],
+ joined_members=r[4],
+ join_rules=r[8],
+ guest_access=r[7],
+ history_visibility=r[6],
+ state_events=0,
+ avatar=r[5],
+ topic=r[2],
+ room_type=r[9],
+ )
+ for r in txn
+ ]
if not forwards:
results.reverse()
return results
- ret_val = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
- return ret_val
@cached(max_entries=10000)
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
@@ -831,7 +883,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Optional[int]]]:
+ ) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -841,7 +893,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,),
)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room",
@@ -856,8 +908,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime,
)
- min_lifetime = ret[0]["min_lifetime"]
- max_lifetime = ret[0]["max_lifetime"]
+ min_lifetime, max_lifetime = ret
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
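
The fallback described in the comment above — any attribute the room's policy leaves unset is taken from the server's default policy — can be sketched as follows (the default values here are made up, not Synapse's configuration):

```python
# Made-up defaults, not Synapse's configuration: unset attributes in a room's
# retention policy fall back to the matching default attribute.
from typing import Optional, Tuple

DEFAULT_MIN_LIFETIME: Optional[int] = None
DEFAULT_MAX_LIFETIME: Optional[int] = 1000 * 60 * 60 * 24 * 365  # one year in ms

def effective_policy(
    row: Optional[Tuple[Optional[int], Optional[int]]]
) -> Tuple[Optional[int], Optional[int]]:
    if row is None:
        return DEFAULT_MIN_LIFETIME, DEFAULT_MAX_LIFETIME
    min_lifetime, max_lifetime = row
    if min_lifetime is None:
        min_lifetime = DEFAULT_MIN_LIFETIME
    if max_lifetime is None:
        max_lifetime = DEFAULT_MAX_LIFETIME
    return min_lifetime, max_lifetime

print(effective_policy(None))
print(effective_policy((60_000, None)))
```
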
@@ -1162,14 +1213,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, args)
- rows = self.db_pool.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = RetentionPolicy(
- min_lifetime=row["min_lifetime"],
- max_lifetime=row["max_lifetime"],
+ rooms_dict = {
+ room_id: RetentionPolicy(
+ min_lifetime=min_lifetime,
+ max_lifetime=max_lifetime,
)
+ for room_id, min_lifetime, max_lifetime in txn
+ }
if include_null:
# If required, do a second query that retrieves all of the rooms we know
@@ -1178,13 +1228,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql)
- rows = self.db_pool.cursor_to_dict(txn)
-
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = RetentionPolicy()
+ for (room_id,) in txn:
+ if room_id not in rooms_dict:
+ rooms_dict[room_id] = RetentionPolicy()
return rooms_dict
@@ -1236,28 +1284,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
room_servers: Dict[str, PartialStateResyncInfo] = {}
- rows = await self.db_pool.simple_select_list(
- table="partial_state_rooms",
- keyvalues={},
- retcols=("room_id", "joined_via"),
- desc="get_server_which_served_partial_join",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- joined_via = row["joined_via"]
+ for room_id, joined_via in rows:
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
- rows = await self.db_pool.simple_select_list(
- "partial_state_rooms_servers",
- keyvalues=None,
- retcols=("room_id", "server_name"),
- desc="get_partial_state_rooms",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- server_name = row["server_name"]
+ for room_id, server_name in rows:
entry = room_servers.get(room_id)
if entry is None:
# There is a foreign key constraint which enforces that every room_id in
@@ -1300,14 +1350,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
complete.
"""
- rows: List[Dict[str, str]] = await self.db_pool.simple_select_many_batch(
- table="partial_state_rooms",
- column="room_id",
- iterable=room_ids,
- retcols=("room_id",),
- desc="is_partial_state_room_batched",
- )
- partial_state_rooms = {row_dict["room_id"] for row_dict in rows}
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="partial_state_rooms",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id",),
+ desc="is_partial_state_room_batched",
+ ),
+ )
+ partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
@@ -1703,24 +1756,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return True
- for row in rows:
- if not row["json"]:
+ for room_id, event_id, json in rows:
+ if not json:
retention_policy = {}
else:
- ev = db_to_json(row["json"])
+ ev = db_to_json(json)
retention_policy = ev["content"]
self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
- "room_id": row["room_id"],
- "event_id": row["event_id"],
+ "room_id": room_id,
+ "event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"),
},
@@ -1729,7 +1782,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
self.db_pool.updates._background_update_progress_txn(
- txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ txn, "insert_room_retention", {"room_id": rows[-1][0]}
)
if batch_size > len(rows):
@@ -2215,7 +2268,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
txn,
table="partial_state_rooms_servers",
keys=("room_id", "server_name"),
- values=((room_id, s) for s in servers),
+ values=[(room_id, s) for s in servers],
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3755773faa..1ed7f2d0ef 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -27,6 +27,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -275,7 +276,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
_get_users_in_room_with_profiles,
)
- @cached(max_entries=100000)
+ @cached(max_entries=100000) # type: ignore[synapse-@cached-mutable]
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
@@ -481,6 +482,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
desc="get_local_users_in_room",
)
+ async def get_local_users_related_to_room(
+ self, room_id: str
+ ) -> List[Tuple[str, str]]:
+ """
+ Retrieves a list of the current room members who are local to the server and their membership status.
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="local_current_membership",
+ keyvalues={"room_id": room_id},
+ retcols=("user_id", "membership"),
+ desc="get_local_users_in_room",
+ ),
+ )
+
async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
"""
Check whether a given local user is currently joined to the given room.
@@ -683,25 +700,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from user_id to set of rooms that is currently in.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="current_state_events",
- column="state_key",
- iterable=user_ids,
- retcols=(
- "state_key",
- "room_id",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="state_key",
+ iterable=user_ids,
+ retcols=(
+ "state_key",
+ "room_id",
+ ),
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ desc="get_rooms_for_users",
),
- keyvalues={
- "type": EventTypes.Member,
- "membership": Membership.JOIN,
- },
- desc="get_rooms_for_users",
)
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
- for row in rows:
- user_rooms[row["state_key"]].add(row["room_id"])
+ for state_key, room_id in rows:
+ user_rooms[state_key].add(room_id)
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@@ -892,17 +912,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
Map from event ID to `user_id`, or None if event is not a join.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=event_ids,
- retcols=("user_id", "event_id"),
- keyvalues={"membership": Membership.JOIN},
- batch_size=1000,
- desc="_get_user_ids_from_membership_event_ids",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("event_id", "user_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=1000,
+ desc="_get_user_ids_from_membership_event_ids",
+ ),
)
- return {row["event_id"]: row["user_id"] for row in rows}
+ return dict(rows)
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -933,7 +956,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, membership, room_id, like_clause
+ "is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@@ -1063,15 +1086,19 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
- rows = await self.db_pool.simple_select_list(
- "current_state_events",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "membership"),
- desc="has_completed_background_updates",
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ ),
)
- return {row["event_id"]: row["membership"] for row in rows}
+ return dict(rows)
- @cached(max_entries=10000)
+ # TODO This returns a mutable object, which is generally confusing when using a cache.
+ @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache()
@@ -1157,7 +1184,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
- rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+ rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet
@@ -1201,21 +1228,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
membership event, otherwise the value is None.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids,
- retcols=("user_id", "membership", "event_id"),
- keyvalues={},
- batch_size=500,
- desc="get_membership_from_event_ids",
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ ),
)
return {
- row["event_id"]: EventIdMembership(
- membership=row["membership"], user_id=row["user_id"]
- )
- for row in rows
+ event_id: EventIdMembership(membership=membership, user_id=user_id)
+ for user_id, membership, event_id in rows
}
async def is_local_host_in_room_ignoring_users(
@@ -1348,18 +1376,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
+ for _, event_id, room_id, json in rows:
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index d45d2ecc98..4c4112e3b2 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -26,6 +26,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -105,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore):
txn,
table="event_search",
keys=("event_id", "room_id", "key", "value"),
- values=(
+ values=[
(
entry.event_id,
entry.room_id,
@@ -113,7 +114,7 @@ class SearchWorkerStore(SQLBaseStore):
_clean_value_for_search(entry.value),
)
for entry in entries
- ),
+ ],
)
else:
@@ -179,22 +180,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
if not rows:
return 0
- min_stream_id = rows[-1]["stream_ordering"]
+ min_stream_id = rows[-1][0]
event_search_rows = []
- for row in rows:
+ for (
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ json,
+ origin_server_ts,
+ ) in rows:
try:
- event_id = row["event_id"]
- room_id = row["room_id"]
- etype = row["type"]
- stream_ordering = row["stream_ordering"]
- origin_server_ts = row["origin_server_ts"]
try:
- event_json = db_to_json(row["json"])
+ event_json = db_to_json(json)
content = event_json["content"]
except Exception:
continue
@@ -504,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = await self.db_pool.execute(
- "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id).
+ results = cast(
+ List[Tuple[Union[int, float], str, str]],
+ await self.db_pool.execute("search_msgs", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -525,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ {"event": event_map[r[2]], "rank": r[0]}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
@@ -602,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
- origin_server_ts, stream_ordering, room_id, event_id
+ room_id, event_id, origin_server_ts, stream_ordering
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
@@ -663,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int`
args.append(limit)
- results = await self.db_pool.execute(
- "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+ results = cast(
+ List[Tuple[Union[int, float], str, str, int, int]],
+ await self.db_pool.execute("search_rooms", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -684,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
+ "event": event_map[r[2]],
+ "rank": r[0],
+ "pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5eaaff5b68..598025dd91 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -20,10 +20,12 @@ from typing import (
Collection,
Dict,
Iterable,
+ List,
Mapping,
Optional,
Set,
Tuple,
+ cast,
)
import attr
@@ -388,16 +390,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
RuntimeError if the state is unknown at any of the given events
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_to_state_groups",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id", "state_group"),
- desc="_get_state_group_for_events",
+ rows = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group"),
+ desc="_get_state_group_for_events",
+ ),
)
- res = {row["event_id"]: row["state_group"] for row in rows}
+ res = dict(rows)
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
@@ -415,16 +420,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The subset of state groups that are referenced.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="event_to_state_groups",
- column="state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("DISTINCT state_group",),
- desc="get_referenced_state_groups",
+ rows = cast(
+ List[Tuple[int]],
+ await self.db_pool.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("DISTINCT state_group",),
+ desc="get_referenced_state_groups",
+ ),
)
- return {row["state_group"] for row in rows}
+ return {row[0] for row in rows}
async def update_state_for_partial_state_event(
self,
@@ -624,16 +632,22 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
# potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may
# have missed any device updates.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="current_state_events",
- column="room_id",
- iterable=to_delete,
- keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
- retcols=("state_key",),
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=to_delete,
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ retcols=("state_key",),
+ ),
)
- potentially_left_users = {row["state_key"] for row in rows}
+ potentially_left_users = {row[0] for row in rows}
# Now lets actually delete the rooms from the DB.
self.db_pool.simple_delete_many_txn(
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 445213e12a..3151186e0c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Tuple
+from typing import List, Optional, Tuple
+
+import attr
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
@@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateDelta:
+ stream_id: int
+ room_id: str
+ event_type: str
+ state_key: str
+
+ event_id: Optional[str]
+ """new event_id for this state key. None if the state has been deleted."""
+
+ prev_event_id: Optional[str]
+ """previous event_id for this state key. None if it's new state."""
+
+
class StateDeltasStore(SQLBaseStore):
# This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this?
@@ -29,31 +45,21 @@ class StateDeltasStore(SQLBaseStore):
async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id
- Each entry in the result contains the following fields:
- - stream_id (int)
- - room_id (str)
- - type (str): event type
- - state_key (str):
- - event_id (str|None): new event_id for this state key. None if the
- state has been deleted.
- - prev_event_id (str|None): previous event_id for this state key. None
- if it's new state.
-
This may be the partial state if we're lazy joining the room.
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- - ie, an upper limit to return changes from.
+ - ie, an upper limit to return changes from.
Returns:
A tuple consisting of:
- - the stream id which these results go up to
- - list of current_state_delta_stream rows. If it is empty, we are
- up to date.
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
@@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
# First we calculate the max stream id that will give us less than
# N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
+ return clipped_stream_id, [
+ StateDelta(
+ stream_id=row[0],
+ room_id=row[1],
+ event_type=row[2],
+ state_key=row[3],
+ event_id=row[4],
+ prev_event_id=row[5],
+ )
+ for row in txn.fetchall()
+ ]
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9d403919e4..e96c9b0486 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -506,25 +506,28 @@ class StatsStore(StateDeltasStore):
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="current_state_events",
- column="type",
- iterable=[
- EventTypes.Create,
- EventTypes.JoinRules,
- EventTypes.RoomHistoryVisibility,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- EventTypes.Topic,
- EventTypes.RoomAvatar,
- EventTypes.CanonicalAlias,
- ],
- keyvalues={"room_id": room_id, "state_key": ""},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="type",
+ iterable=[
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.RoomAvatar,
+ EventTypes.CanonicalAlias,
+ ],
+ keyvalues={"room_id": room_id, "state_key": ""},
+ retcols=["event_id"],
+ ),
)
- event_ids = cast(List[str], [row["event_id"] for row in rows])
+ event_ids = [row[0] for row in rows]
txn.execute(
"""
@@ -676,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
@@ -689,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by
+
Returns:
- A list of user dicts and an integer representing the total number of
- users that exist given this query
+ A tuple of:
+ A list of tuples of user information (the user ID, displayname,
+ total number of media files, total length of media in bytes), and
+
+ An integer representing the total number of users that exist
+ given this query.
"""
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []
@@ -770,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
+ users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 5a3611c415..2225f8272d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -266,7 +266,7 @@ def generate_next_token(
# when we are going backwards so we subtract one from the
# stream part.
last_stream_ordering -= 1
- return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+ return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
def _make_generic_sql_bound(
@@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, immutabledict(positions))
+ return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+ key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return RoomStreamToken(topo, stream_ordering)
+ return RoomStreamToken(topological=topo, stream=stream_ordering)
@overload
def get_stream_id_for_event_txn(
@@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+ return RoomStreamToken(
+ topological=row["topological_ordering"], stream=row["stream_ordering"]
+ )
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -1076,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, room_id, stream_key
+ "get_current_topological_token", sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
topo = None
internal = event.internal_metadata
- internal.before = RoomStreamToken(topo, stream - 1)
- internal.after = RoomStreamToken(topo, stream)
+ internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+ internal.after = RoomStreamToken(topological=topo, stream=stream)
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
@@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
- results["topological_ordering"] - 1, results["stream_ordering"]
+ topological=results["topological_ordering"] - 1,
+ stream=results["stream_ordering"],
)
after_token = RoomStreamToken(
- results["topological_ordering"], results["stream_ordering"]
+ topological=results["topological_ordering"],
+ stream=results["stream_ordering"],
)
rows, start_token = self._paginate_room_events_txn(
@@ -1612,3 +1616,49 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcol="instance_name",
desc="get_name_from_instance_id",
)
+
+ async def get_timeline_gaps(
+ self,
+ room_id: str,
+ from_token: Optional[RoomStreamToken],
+ to_token: RoomStreamToken,
+ ) -> Optional[RoomStreamToken]:
+ """Check if there is a gap, and return a token that marks the position
+ of the gap in the stream.
+ """
+
+ sql = """
+ SELECT instance_name, stream_ordering
+ FROM timeline_gaps
+ WHERE room_id = ? AND ? < stream_ordering AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ """
+
+ rows = await self.db_pool.execute(
+ "get_timeline_gaps",
+ sql,
+ room_id,
+ from_token.stream if from_token else 0,
+ to_token.get_max_stream_pos(),
+ )
+
+ if not rows:
+ return None
+
+ positions = [
+ PersistedEventPosition(instance_name, stream_ordering)
+ for instance_name, stream_ordering in rows
+ ]
+ if from_token:
+ positions = [p for p in positions if p.persisted_after(from_token)]
+
+ positions = [p for p in positions if not p.persisted_after(to_token)]
+
+ if positions:
+ # We return a stream token that ensures the event *at* the position
+ # of the gap is included (as the gap is *before* the persisted
+ # event).
+ last_position = positions[-1]
+ return RoomStreamToken(stream=last_position.stream - 1)
+
+ return None
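The selection logic in get_timeline_gaps reduces to: keep only gap positions strictly after the lower bound and at or before the upper bound, then step the returned token back by one so the event at the gap is still included. A hedged pure-Python sketch of that step, with plain ints standing in for RoomStreamToken / PersistedEventPosition (a simplification, since the real tokens also carry per-instance positions):

from typing import List, Optional


def gap_token(
    gap_stream_orderings: List[int],
    from_stream: Optional[int],
    to_stream: int,
) -> Optional[int]:
    # Keep gaps inside the (from, to] window, mirroring the
    # persisted_after() filtering above.
    positions = [
        s
        for s in sorted(gap_stream_orderings)
        if (from_stream is None or s > from_stream) and s <= to_stream
    ]
    if not positions:
        return None
    # Back off by one so the event *at* the last gap is returned when the
    # caller paginates up to this position.
    return positions[-1] - 1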
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 61403a98cf..7deda7790e 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content.
"""
- rows = await self.db_pool.simple_select_list(
- "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_list(
+ "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ ),
)
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
+ for room_id, tag, content in rows:
+ room_tags = tags_by_room.setdefault(room_id, {})
+ room_tags[tag] = db_to_json(content)
return tags_by_room
async def get_all_updated_tags(
@@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A mapping of tags to tag content.
"""
- rows = await self.db_pool.simple_select_list(
- table="room_tags",
- keyvalues={"user_id": user_id, "room_id": room_id},
- retcols=("tag", "content"),
- desc="get_tags_for_room",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="room_tags",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=("tag", "content"),
+ desc="get_tags_for_room",
+ ),
)
- return {row["tag"]: db_to_json(row["content"]) for row in rows}
+ return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5c5372a825..5555b53575 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
+ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
+
class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
@@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs)
@staticmethod
- def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
- row["status"] = TaskStatus(row["status"])
- if row["params"] is not None:
- row["params"] = db_to_json(row["params"])
- if row["result"] is not None:
- row["result"] = db_to_json(row["result"])
- return ScheduledTask(**row)
+ def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
+ task_id, action, status, timestamp, resource_id, params, result, error = row
+ return ScheduledTask(
+ id=task_id,
+ action=action,
+ status=TaskStatus(status),
+ timestamp=timestamp,
+ resource_id=resource_id,
+ params=db_to_json(params) if params is not None else None,
+ result=db_to_json(result) if result is not None else None,
+ error=error,
+ )
async def get_scheduled_tasks(
self,
@@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
- def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = []
args: List[Any] = []
if resource_id:
@@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[ScheduledTaskRow], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn
@@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task",
)
- return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+ return (
+ TaskSchedulerWorkerStore._convert_row_to_task(
+ (
+ row["id"],
+ row["action"],
+ row["status"],
+ row["timestamp"],
+ row["resource_id"],
+ row["params"],
+ row["result"],
+ row["error"],
+ )
+ )
+ if row
+ else None
+ )
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 8f70eff809..fecddb4144 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -211,18 +211,28 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
async def get_destination_retry_timings_batch(
self, destinations: StrCollection
) -> Mapping[str, Optional[DestinationRetryTimings]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="destinations",
- iterable=destinations,
- column="destination",
- retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
- desc="get_destination_retry_timings_batch",
+ rows = cast(
+ List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_many_batch(
+ table="destinations",
+ iterable=destinations,
+ column="destination",
+ retcols=(
+ "destination",
+ "failure_ts",
+ "retry_last_ts",
+ "retry_interval",
+ ),
+ desc="get_destination_retry_timings_batch",
+ ),
)
return {
- row.pop("destination"): DestinationRetryTimings(**row)
- for row in rows
- if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"]
+ destination: DestinationRetryTimings(
+ failure_ts, retry_last_ts, retry_interval
+ )
+ for destination, failure_ts, retry_last_ts, retry_interval in rows
+ if retry_last_ts and failure_ts and retry_interval
}
async def set_destination_retry_timings(
@@ -468,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+ int,
+ ]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
total number of destinations matching the filter criteria.
@@ -480,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns:
- A tuple of a list of mappings from destination to information
+ A tuple of a list of tuples of destination information:
+ * destination
+ * retry_last_ts
+ * retry_interval
+ * failure_ts
+ * last_successful_stream_ordering
and a count of total destinations.
"""
def get_destinations_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[
+ Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+ ],
+ int,
+ ]:
order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS:
@@ -513,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ?
"""
txn.execute(sql, args + [limit, start])
- destinations = self.db_pool.cursor_to_dict(txn)
+ destinations = cast(
+ List[
+ Tuple[
+ str, Optional[int], Optional[int], Optional[int], Optional[int]
+ ]
+ ],
+ txn.fetchall(),
+ )
return destinations, count
return await self.db_pool.runInteraction(
@@ -526,7 +556,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
start: int,
limit: int,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, int]], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
total number of rooms.
@@ -537,12 +567,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: number of rows to retrieve
direction: sort ascending or descending by room_id
Returns:
- A tuple of a dict of rooms and a count of total rooms.
+ A tuple of a list of room tuples and a count of total rooms.
+
+ Each room tuple is (room_id, stream_ordering).
"""
def get_destination_rooms_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, int]], int]:
if direction == Direction.BACKWARDS:
order = "DESC"
else:
@@ -556,14 +588,17 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, [destination])
count = cast(Tuple[int], txn.fetchone())[0]
- rooms = self.db_pool.simple_select_list_paginate_txn(
- txn=txn,
- table="destination_rooms",
- orderby="room_id",
- start=start,
- limit=limit,
- retcols=("room_id", "stream_ordering"),
- order_direction=order,
+ rooms = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_list_paginate_txn(
+ txn=txn,
+ table="destination_rooms",
+ orderby="room_id",
+ start=start,
+ limit=limit,
+ retcols=("room_id", "stream_ordering"),
+ order_direction=order,
+ ),
)
return rooms, count
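get_destination_retry_timings_batch above shows the other recurring shape: unpack each row tuple directly in the comprehension and drop rows whose timing columns are NULL before constructing the value object. A standalone sketch with an illustrative retry-timings class (not Synapse's DestinationRetryTimings):

from typing import Dict, List, Optional, Tuple

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExampleRetryTimings:
    failure_ts: int
    retry_last_ts: int
    retry_interval: int


def timings_by_destination(
    rows: List[Tuple[str, Optional[int], Optional[int], Optional[int]]],
) -> Dict[str, ExampleRetryTimings]:
    return {
        destination: ExampleRetryTimings(failure_ts, retry_last_ts, retry_interval)
        for destination, failure_ts, retry_last_ts, retry_interval in rows
        # Skip rows where any timing column is NULL or zero, matching the
        # dict-based filter this replaces.
        if failure_ts and retry_last_ts and retry_interval
    }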
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index f38bedbbcd..8ab7c42c4a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db_pool.simple_select_list(
- table="ui_auth_sessions_credentials",
- keyvalues={"session_id": session_id},
- retcols=("stage_type", "result"),
- desc="get_completed_ui_auth_stages",
- ):
- results[row["stage_type"]] = db_to_json(row["result"])
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ),
+ )
+ for stage_type, result in rows:
+ results[stage_type] = db_to_json(result)
return results
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
- rows = await self.db_pool.simple_select_list(
- table="ui_auth_sessions_ips",
- keyvalues={"session_id": session_id},
- retcols=("user_agent", "ip"),
- desc="get_user_agents_ips_to_ui_auth_session",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ ),
)
- return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
@@ -337,13 +343,16 @@ class UIAuthWorkerStore(SQLBaseStore):
# If a registration token was used, decrement the pending counter
# before deleting the session.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="ui_auth_sessions_credentials",
- column="session_id",
- iterable=session_ids,
- keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
- retcols=["result"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ ),
)
# Get the tokens used and how much pending needs to be decremented by.
@@ -353,23 +362,25 @@ class UIAuthWorkerStore(SQLBaseStore):
# registration token stage for that session will be True.
# If a token was used to authenticate, but registration was
# never completed, the result will be the token used.
- token = db_to_json(r["result"])
+ token = db_to_json(r[0])
if isinstance(token, str):
token_counts[token] = token_counts.get(token, 0) + 1
# Update the `pending` counters.
if len(token_counts) > 0:
- token_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="registration_tokens",
- column="token",
- iterable=list(token_counts.keys()),
- keyvalues={},
- retcols=["token", "pending"],
+ token_rows = cast(
+ List[Tuple[str, int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ ),
)
- for token_row in token_rows:
- token = token_row["token"]
- new_pending = token_row["pending"] - token_counts[token]
+ for token, pending in token_rows:
+ new_pending = pending - token_counts[token]
self.db_pool.simple_update_one_txn(
txn,
table="registration_tokens",
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ed41e52201..d4b86ed7a6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -415,25 +415,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
# Next fetch their profiles. Note that not all users have profiles.
- profile_rows = self.db_pool.simple_select_many_txn(
- txn,
- table="profiles",
- column="full_user_id",
- iterable=list(users_to_insert),
- retcols=(
- "full_user_id",
- "displayname",
- "avatar_url",
+ profile_rows = cast(
+ List[Tuple[str, Optional[str], Optional[str]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="profiles",
+ column="full_user_id",
+ iterable=list(users_to_insert),
+ retcols=(
+ "full_user_id",
+ "displayname",
+ "avatar_url",
+ ),
+ keyvalues={},
),
- keyvalues={},
)
profiles = {
- row["full_user_id"]: _UserDirProfile(
- row["full_user_id"],
- row["displayname"],
- row["avatar_url"],
- )
- for row in profile_rows
+ full_user_id: _UserDirProfile(full_user_id, displayname, avatar_url)
+ for full_user_id, displayname, avatar_url in profile_rows
}
profiles_to_insert = [
@@ -522,18 +521,21 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
and not self.get_if_app_services_interested_in_user(user) # type: ignore[attr-defined]
]
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="users",
- column="name",
- iterable=users,
- keyvalues={
- "deactivated": 0,
- },
- retcols=("name", "user_type"),
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="users",
+ column="name",
+ iterable=users,
+ keyvalues={
+ "deactivated": 0,
+ },
+ retcols=("name", "user_type"),
+ ),
)
- return [row["name"] for row in rows if row["user_type"] != UserTypes.SUPPORT]
+ return [name for name, user_type in rows if user_type != UserTypes.SUPPORT]
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
"""Check if the room is either world_readable or publically joinable"""
@@ -1178,15 +1180,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine")
results = cast(
- List[UserProfile],
- await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
- ),
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.db_pool.execute("search_user_dir", sql, *args),
)
limited = len(results) > limit
- return {"limited": limited, "results": results[0:limit]}
+ return {
+ "limited": limited,
+ "results": [
+ {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+ for r in results[0:limit]
+ ],
+ }
def _filter_text_for_index(text: str) -> str:
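The search_user_dir change above moves dict construction out of the storage layer: the query now yields (user_id, display_name, avatar_url) tuples and only the JSON response re-shapes them. A small sketch of that response-shaping step, using the same keys as the hunk above:

from typing import Dict, List, Optional, Tuple, Union

UserDirRow = Tuple[str, Optional[str], Optional[str]]


def shape_search_results(
    results: List[UserDirRow], limit: int
) -> Dict[str, Union[bool, List[Dict[str, Optional[str]]]]]:
    # "limited" is true when more rows came back than the requested page size.
    limited = len(results) > limit
    return {
        "limited": limited,
        "results": [
            {
                "user_id": user_id,
                "display_name": display_name,
                "avatar_url": avatar_url,
            }
            for user_id, display_name, avatar_url in results[:limit]
        ],
    }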
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 06fcbe5e54..8bd58c6e3d 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, Mapping
+from typing import Iterable, List, Mapping, Tuple, cast
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -50,14 +50,17 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
Returns:
for each user, whether the user has requested erasure.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="erased_users",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="are_users_erased",
+ rows = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="erased_users",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="are_users_erased",
+ ),
)
- erased_users = {row["user_id"] for row in rows}
+ erased_users = {row[0] for row in rows}
return {u: u in erased_users for u in user_ids}
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d2e942cbd3..2c3151526d 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None:
rows = await self.db_pool.execute(
"_background_deduplicate_state",
- None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 6984d11352..182e429174 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,7 +13,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
@@ -144,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db_pool.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
+ delta_ids = cast(
+ List[Tuple[str, str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ ),
)
return _GetStateGroupDelta(
prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ {
+ (event_type, state_key): event_id
+ for event_type, state_key, event_id in delta_ids
+ },
)
return await self.db_pool.runInteraction(
@@ -730,19 +746,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- retcols=("state_group",),
+ rows = cast(
+ List[Tuple[int]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ ),
)
remaining_state_groups = {
- row["state_group"]
- for row in rows
- if row["state_group"] not in state_groups_to_delete
+ state_group
+ for state_group, in rows
+ if state_group not in state_groups_to_delete
}
logger.info(
@@ -799,16 +818,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A mapping from state group to previous state group.
"""
- rows = await self.db_pool.simple_select_many_batch(
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
- desc="get_previous_state_groups",
+ rows = cast(
+ List[Tuple[int, int]],
+ await self.db_pool.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("state_group", "prev_state_group"),
+ desc="get_previous_state_groups",
+ ),
)
- return {row["state_group"]: row["prev_state_group"] for row in rows}
+ return dict(rows)
async def purge_room_state(
self, room_id: str, state_groups_to_delete: Collection[int]