diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 39498d52c6..d7482a1f4e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
- extra_tables=[("room_tags_revisions", "stream_id")],
+ extra_tables=[
+ ("account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
is_writer=self._instance_name in hs.config.worker.writers.account_data,
)
@@ -283,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
- ) -> Dict[str, JsonDict]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id, "room_id": room_id},
- ["account_data_type", "content"],
+ ) -> Dict[str, JsonMapping]:
+ rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=["account_data_type", "content"],
+ ),
)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in rows
}
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 073a99cd84..fa7d1c469a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
- results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state.value}, ["as_id"]
+ results = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="application_services_state",
+ keyvalues={"state": state.value},
+ retcols=("as_id",),
+ ),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
- for res in results:
+ for (as_id,) in results:
for service in as_list:
- if service.id == res["as_id"]:
+ if service.id == as_id:
services.append(service)
return services
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..4d0470ffd9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
+ EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
+ elif row.type == EventsStreamAllStateRow.TypeId:
+ assert isinstance(data, EventsStreamAllStateRow)
+ # Similar to the above, but the entire caches are invalidated. This is
+ # unfortunate for the membership caches, but should recover quickly.
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
+ self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
+ self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..711fdddd4e 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ "_censor_redactions_fetch", sql, before_ts, 100
)
updates = []
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index bf5b8c804b..4ea56331c7 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
if device_id is not None:
keyvalues["device_id"] = device_id
- res = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ res = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
)
return {
- (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
- user_id=d["user_id"],
- device_id=d["device_id"],
- ip=d["ip"],
- user_agent=d["user_agent"],
- last_seen=d["last_seen"],
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=ip,
+ user_agent=user_agent,
+ last_seen=last_seen,
)
- for d in res
+ for user_id, ip, user_agent, device_id, last_seen in res
}
async def _get_user_ip_and_agents_from_database(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 72dc4f54dc..71eefe6b7c 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -478,18 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
- ROW_ID_NAME = self.database_engine.row_id_name
-
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
limit_statement = "" if limit is None else f"LIMIT {limit}"
sql = f"""
- DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
- SELECT {ROW_ID_NAME} FROM device_inbox
- WHERE user_id = ? AND device_id = ? AND stream_id <= ?
- {limit_statement}
+ DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
+ SELECT MAX(stream_id) FROM (
+ SELECT stream_id FROM device_inbox
+ WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+ ORDER BY stream_id
+ {limit_statement}
+ ) AS q1
)
"""
- txn.execute(sql, (user_id, device_id, up_to_stream_id))
+ txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
return txn.rowcount
count = await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a07086149c..ae0536fbaf 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -285,7 +285,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
- async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
+ async def get_devices_by_user(
+ self, user_id: str
+ ) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@@ -293,20 +295,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
- and "display_name" for each device.
+ and "display_name" for each device. Display name may be null.
"""
- devices = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user",
+ devices = cast(
+ List[Tuple[str, str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user",
+ ),
)
- return {d["device_id"]: d for d in devices}
+ return {
+ d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
+ for d in devices
+ }
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@@ -315,14 +323,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
- return await self.db_pool.simple_select_list(
- table="device_auth_providers",
- keyvalues={
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- retcols=("user_id", "device_id"),
- desc="get_devices_by_auth_provider_session_id",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ ),
)
@trace
@@ -823,15 +834,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
- devices = await self.db_pool.simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={"user_id": user_id},
- retcols=("device_id", "content"),
- desc="get_cached_devices_for_user",
+ devices = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("device_id", "content"),
+ desc="get_cached_devices_for_user",
+ ),
)
- return {
- device["device_id"]: db_to_json(device["content"]) for device in devices
- }
+ return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@@ -884,7 +896,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
- None,
sql,
from_key,
to_key,
@@ -968,7 +979,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
- "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ "get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
@@ -1082,7 +1093,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- row_tuples = cast(
+ rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
@@ -1092,11 +1103,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)
-
- return {row[0] for row in row_tuples}
else:
rows = cast(
- List[Dict[str, str]],
+ List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
@@ -1105,7 +1114,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
),
)
- return {row["user_id"] for row in rows}
+ return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index aac4cfb054..ad904a26a6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = await self.db_pool.simple_select_list(
- table="e2e_room_keys",
- keyvalues=keyvalues,
- retcols=(
- "user_id",
- "room_id",
- "session_id",
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
+ rows = cast(
+ List[Tuple[str, str, int, int, int, str]],
+ await self.db_pool.simple_select_list(
+ table="e2e_room_keys",
+ keyvalues=keyvalues,
+ retcols=(
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ desc="get_e2e_room_keys",
),
- desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
- for row in rows:
- room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
- room_entry["sessions"][row["session_id"]] = {
- "first_message_index": row["first_message_index"],
- "forwarded_count": row["forwarded_count"],
+ for (
+ room_id,
+ session_id,
+ first_message_index,
+ forwarded_count,
+ is_verified,
+ session_data,
+ ) in rows:
+ room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
+ room_entry["sessions"][session_id] = {
+ "first_message_index": first_message_index,
+ "forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
- "is_verified": bool(row["is_verified"]),
- "session_data": db_to_json(row["session_data"]),
+ "is_verified": bool(is_verified),
+ "session_data": db_to_json(session_data),
}
return sessions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f13d776b0d..f70f95eeba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
- None,
sql,
now_stream_id,
user_id,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4f80ce75cc..f1b0991503 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
- rows = await self.db_pool.simple_select_list(
- table="federation_inbound_events_staging",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "event_json"),
- desc="prune_staged_events_in_room_fetch",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ ),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
- for row in rows:
- event_id = row["event_id"]
+ for event_id, event_json in rows:
seen_events.add(event_id)
- event_d = db_to_json(row["event_json"])
+ event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ef6766b5e0..3c1492e3ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2267,35 +2267,59 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- # From the events passed in, add all of the prev events as backwards extremities.
- # Ignore any events that are already backwards extrems or outliers.
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- # 1. Don't add an event as a extremity again if we already persisted it
- # as a non-outlier.
- # 2. Don't add an outlier as an extremity if it has no prev_events
- " AND NOT EXISTS ("
- " SELECT 1 FROM events"
- " LEFT JOIN event_edges edge"
- " ON edge.event_id = events.event_id"
- " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = FALSE OR edge.event_id IS NULL)"
- " )"
+
+ room_id = events[0].room_id
+
+ potential_backwards_extremities = {
+ e_id
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ }
+
+ if not potential_backwards_extremities:
+ return
+
+ existing_events_outliers = self.db_pool.simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=potential_backwards_extremities,
+ keyvalues={"outlier": False},
+ retcols=("event_id",),
)
- txn.execute_batch(
- query,
- [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id)
- for ev in events
- for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ],
+ potential_backwards_extremities.difference_update(
+ e for e, in existing_events_outliers
)
+ if potential_backwards_extremities:
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ table="event_backward_extremities",
+ key_names=("room_id", "event_id"),
+ key_values=[(room_id, ev) for ev in potential_backwards_extremities],
+ value_names=(),
+ value_values=(),
+ )
+
+ # Record the stream orderings where we have new gaps.
+ gap_events = [
+ (room_id, self._instance_name, ev.internal_metadata.stream_ordering)
+ for ev in events
+ if any(
+ e_id in potential_backwards_extremities
+ for e_id in ev.prev_event_ids()
+ )
+ ]
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="timeline_gaps",
+ keys=("room_id", "instance_name", "stream_ordering"),
+ values=gap_events,
+ )
+
# Delete all these events that we've already fetched and now know that their
# prev events are the new backwards extremeties.
query = (
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c5fce1c82b..0061805150 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
- # We need to pass execute a dummy function to handle the txn's result otherwise
- # it tries to call fetchall() on it and fails because there's no result to fetch.
- await self.db_pool.execute(
+ await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
- lambda txn: None,
- "ANALYZE events(stream_ordering2)",
+ lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index f851bff604..0ba84b1469 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
- ) -> List[Dict[str, Any]]:
- """Get list of forward extremities for a room."""
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
+ """
+ Get list of forward extremities for a room.
+
+ Returns:
+ A list of tuples of event_id, state_group, depth, and received_ts.
+ """
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 89757eabed..9a6c905617 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2096,12 +2096,6 @@ class EventsWorkerStore(SQLBaseStore):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
- DELETE FROM event_txn_id
- WHERE inserted_ts < ?
- """
- txn.execute(sql, (one_day_ago,))
-
- sql = """
DELETE FROM event_txn_id_device_id
WHERE inserted_ts < ?
"""
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index 654f924019..60621edeef 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, FrozenSet
+from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
- enabled = await self.db_pool.simple_select_list(
- "per_user_experimental_features",
- {"user_id": user_id, "enabled": True},
- ["feature"],
+ enabled = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="per_user_experimental_features",
+ keyvalues={"user_id": user_id, "enabled": True},
+ retcols=("feature",),
+ ),
)
- return frozenset(feature["feature"] for feature in enabled)
+ return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ea797864b9..ce88772f9e 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_list(
- table="server_keys_json",
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
@@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
- rows.sort(key=lambda r: r["ts_added_ms"])
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2e6b176bd2..aeb3db596c 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
- rows = await self.db_pool.simple_select_list(
- "local_media_repository_thumbnails",
- {"media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "local_media_repository_thumbnails",
+ {"media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_local_media_thumbnails",
),
- desc="get_local_media_thumbnails",
)
return [
ThumbnailInfo(
- width=row["thumbnail_width"],
- height=row["thumbnail_height"],
- method=row["thumbnail_method"],
- type=row["thumbnail_type"],
- length=row["thumbnail_length"],
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
) -> List[ThumbnailInfo]:
- rows = await self.db_pool.simple_select_list(
- "remote_media_cache_thumbnails",
- {"media_origin": origin, "media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "remote_media_cache_thumbnails",
+ {"media_origin": origin, "media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_remote_media_thumbnails",
),
- desc="get_remote_media_thumbnails",
)
return [
ThumbnailInfo(
- width=row["thumbnail_width"],
- height=row["thumbnail_height"],
- method=row["thumbnail_method"],
- type=row["thumbnail_type"],
- length=row["thumbnail_length"],
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -652,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -666,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
+ * The filesystem ID.
+ """
+
+ sql = """
+ SELECT media_origin, media_id, filesystem_id
+ FROM remote_media_cache
+ WHERE last_access_ts < ?
"""
- sql = (
- "SELECT media_origin, media_id, filesystem_id"
- " FROM remote_media_cache"
- " WHERE last_access_ts < ?"
- )
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -679,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
- return await self.db_pool.execute(
- "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+ return cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index f5356e7f80..22025eca56 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -179,46 +179,44 @@ class PushRulesWorkerStore(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
- rows = await self.db_pool.simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
+ rows = cast(
+ List[Tuple[str, int, int, str, str]],
+ await self.db_pool.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_for_user",
),
- desc="get_push_rules_for_user",
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(
- [
- (
- row["rule_id"],
- row["priority_class"],
- row["conditions"],
- row["actions"],
- )
- for row in rows
- ],
+ [(row[0], row[1], row[3], row[4]) for row in rows],
enabled_map,
self.hs.config.experimental,
)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
- results = await self.db_pool.simple_select_list(
- table="push_rules_enable",
- keyvalues={"user_name": user_id},
- retcols=("rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
+ results = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ ),
)
- return {r["rule_id"]: bool(r["enabled"]) for r in results}
+ return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c7eb7fc478..a6a1671bd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
async def get_throttle_params_by_room(
self, pusher_id: int
) -> Dict[str, ThrottleParams]:
- res = await self.db_pool.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
+ res = cast(
+ List[Tuple[str, Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ ),
)
params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"],
- row["throttle_ms"],
+ for room_id, last_sent_ts, throttle_ms in res:
+ params_by_room[room_id] = ThrottleParams(
+ last_sent_ts or 0, throttle_ms or 0
)
return params_by_room
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b2645ab43c..56e8eb16a8 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from immutabledict import immutabledict
+
from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -43,7 +45,12 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ MultiWriterStreamToken,
+ PersistedPosition,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -105,7 +112,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipts_linearized",
entity_column="room_id",
stream_column="stream_id",
- max_value=max_receipts_stream_id,
+ max_value=max_receipts_stream_id.stream,
limit=10000,
)
self._receipts_stream_cache = StreamChangeCache(
@@ -114,9 +121,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
- def get_max_receipt_stream_id(self) -> int:
+ def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
"""Get the current max stream ID for receipts stream"""
- return self._receipts_id_gen.get_current_token()
+
+ min_pos = self._receipts_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._receipts_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return MultiWriterStreamToken(
+ stream=min_pos, instance_map=immutabledict(positions)
+ )
+
+ def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
+ return self._receipts_id_gen.get_current_token_for_writer(instance_name)
def get_last_unthreaded_receipt_for_user_txn(
self,
@@ -257,7 +286,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Iterable[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
@@ -276,7 +308,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
room_ids = self._receipts_stream_cache.get_entities_changed(
- room_ids, from_key
+ room_ids, from_key.stream
)
results = await self._get_linearized_receipts_for_rooms(
@@ -286,7 +318,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [ev for res in results.values() for ev in res]
async def get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
@@ -302,36 +337,49 @@ class ReceiptsWorkerStore(SQLBaseStore):
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
- if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+ if not self._receipts_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ ):
return []
return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cached(tree=True)
async def _get_linearized_receipts_for_room(
- self, room_id: str, to_key: int, from_key: Optional[int] = None
+ self,
+ room_id: str,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key:
- sql = (
- "SELECT receipt_type, user_id, event_id, data"
- " FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id > ? AND stream_id <= ?"
- )
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, from_key, to_key))
- else:
- sql = (
- "SELECT receipt_type, user_id, event_id, data"
- " FROM receipts_linearized WHERE"
- " room_id = ? AND stream_id <= ?"
+ txn.execute(
+ sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
)
+ else:
+ sql = """
+ SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
+ FROM receipts_linearized WHERE
+ room_id = ? AND stream_id <= ?
+ """
- txn.execute(sql, (room_id, to_key))
+ txn.execute(sql, (room_id, to_key.get_max_stream_pos()))
- return cast(List[Tuple[str, str, str, str]], txn.fetchall())
+ return [
+ (receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@@ -352,7 +400,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=3,
)
async def _get_linearized_receipts_for_rooms(
- self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+ self,
+ room_ids: Collection[str],
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
@@ -362,7 +413,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
@@ -370,10 +422,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [from_key, to_key] + list(args))
+ txn.execute(
+ sql + clause,
+ [from_key.stream, to_key.get_max_stream_pos()] + list(args),
+ )
else:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type,
+ user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id <= ? AND
"""
@@ -382,11 +438,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
self.database_engine, "room_id", room_ids
)
- txn.execute(sql + clause, [to_key] + list(args))
+ txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return cast(
- List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
- )
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
@@ -420,7 +480,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
- self, to_key: int, from_key: Optional[int] = None
+ self,
+ to_key: MultiWriterStreamToken,
+ from_key: Optional[MultiWriterStreamToken] = None,
) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
@@ -437,25 +499,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [from_key, to_key])
+ txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
else:
sql = """
- SELECT room_id, receipt_type, user_id, event_id, data
+ SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id <= ?
ORDER BY stream_id DESC
LIMIT 100
"""
- txn.execute(sql, [to_key])
+ txn.execute(sql, [to_key.get_max_stream_pos()])
- return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
+ return [
+ (room_id, receipt_type, user_id, event_id, data)
+ for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ from_key, to_key, instance_name, stream_id
+ )
+ ]
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
@@ -545,10 +613,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
+ AND instance_name = ?
ORDER BY stream_id ASC
LIMIT ?
"""
- txn.execute(sql, (last_id, current_id, limit))
+ txn.execute(sql, (last_id, current_id, instance_name, limit))
updates = cast(
List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
@@ -695,6 +764,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues=keyvalues,
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
@@ -750,7 +820,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[int]:
+ ) -> Optional[PersistedPosition]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -812,7 +882,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- return stream_id
+ return PersistedPosition(self._instance_name, stream_id)
async def _insert_graph_receipt(
self,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9e8643ae4d..e09ab21593 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+ address: str
+ """address of the 3pid"""
+ medium: str
+ """medium of the 3pid"""
+ client_secret: str
+ """a secret provided by the client for this validation session"""
+ session_id: str
+ """ID of the validation session"""
+ last_send_attempt: int
+ """a number serving to dedupe send attempts for this session"""
+ validated_at: Optional[int]
+ """timestamp of when this session was validated if so"""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -855,13 +871,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
Tuples of (auth_provider, external_id)
"""
- res = await self.db_pool.simple_select_list(
- table="user_external_ids",
- keyvalues={"user_id": mxid},
- retcols=("auth_provider", "external_id"),
- desc="get_external_ids_by_user",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ ),
)
- return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
@@ -997,13 +1015,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
- results = await self.db_pool.simple_select_list(
- "user_threepids",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address", "validated_at", "added_at"],
- desc="user_get_threepids",
+ results = cast(
+ List[Tuple[str, str, int, int]],
+ await self.db_pool.simple_select_list(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
+ ),
)
- return [ThreepidResult(**r) for r in results]
+ return [
+ ThreepidResult(
+ medium=r[0],
+ address=r[1],
+ validated_at=r[2],
+ added_at=r[3],
+ )
+ for r in results
+ ]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
@@ -1042,7 +1071,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid",
)
- async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
+ async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
@@ -1051,15 +1080,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for
Returns:
- List of dictionaries containing the following keys:
- medium (str): The medium of the threepid (e.g "email")
- address (str): The address of the threepid (e.g "bob@example.com")
- """
- return await self.db_pool.simple_select_list(
- table="user_threepid_id_server",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address"],
- desc="user_get_bound_threepids",
+ List of tuples of two strings:
+ medium: The medium of the threepid (e.g "email")
+ address: The address of the threepid (e.g "bob@example.com")
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ ),
)
async def remove_user_bound_threepid(
@@ -1156,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
@@ -1171,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering
Returns:
- A dict containing the following:
- * address - address of the 3pid
- * medium - medium of the 3pid
- * client_secret - a secret provided by the client for this validation session
- * session_id - ID of the validation session
- * send_attempt - a number serving to dedupe send attempts for this session
- * validated_at - timestamp of when this session was validated if so
-
- Otherwise None if a validation session is not found
+ A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
@@ -1198,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1213,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ return ThreepidValidationSession(
+ address=row[0],
+ session_id=row[1],
+ medium=row[2],
+ client_secret=row[3],
+ last_send_attempt=row[4],
+ validated_at=row[5],
+ )
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7f40e2c446..419b2c7a22 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -47,7 +47,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, StreamKeyType, StreamToken
+from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -314,7 +314,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_key=next_key,
presence_key=0,
typing_key=0,
- receipt_key=0,
+ receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
@@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_list_txn(
- txn=txn,
- table="event_relations",
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_list_txn(
+ txn=txn,
+ table="event_relations",
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9d24d2c347..3e8fcf1975 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
room_servers: Dict[str, PartialStateResyncInfo] = {}
- rows = await self.db_pool.simple_select_list(
- table="partial_state_rooms",
- keyvalues={},
- retcols=("room_id", "joined_via"),
- desc="get_server_which_served_partial_join",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- joined_via = row["joined_via"]
+ for room_id, joined_via in rows:
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
- rows = await self.db_pool.simple_select_list(
- "partial_state_rooms_servers",
- keyvalues=None,
- retcols=("room_id", "server_name"),
- desc="get_partial_state_rooms",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- server_name = row["server_name"]
+ for room_id, server_name in rows:
entry = room_servers.get(room_id)
if entry is None:
# There is a foreign key constraint which enforces that every room_id in
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3a87eba430..67e149b586 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, membership, room_id, like_clause
+ "is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
- rows = await self.db_pool.simple_select_list(
- "current_state_events",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "membership"),
- desc="has_completed_background_updates",
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ ),
)
- return {row["event_id"]: row["membership"] for row in rows}
+ return dict(rows)
# TODO This returns a mutable object, which is generally confusing when using a cache.
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
@@ -1165,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
- rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+ rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1d69c4a5f0..dbde9130c6 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -26,6 +26,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = await self.db_pool.execute(
- "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id).
+ results = cast(
+ List[Tuple[Union[int, float], str, str]],
+ await self.db_pool.execute("search_msgs", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ {"event": event_map[r[2]], "rank": r[0]}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
@@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
- origin_server_ts, stream_ordering, room_id, event_id
+ room_id, event_id, origin_server_ts, stream_ordering
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
@@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int`
args.append(limit)
- results = await self.db_pool.execute(
- "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+ results = cast(
+ List[Tuple[Union[int, float], str, str, int, int]],
+ await self.db_pool.execute("search_rooms", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
+ "event": event_map[r[2]],
+ "rank": r[0],
+ "pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5b2d0ba870..e96c9b0486 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
@@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by
+
Returns:
- A list of user dicts and an integer representing the total number of
- users that exist given this query
+ A tuple of:
+ A list of tuples of user information (the user ID, displayname,
+ total number of media, total length of media) and
+
+ An integer representing the total number of users that exist
+ given this query
"""
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []
@@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
+ users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ea06e4eee0..2225f8272d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, room_id, stream_key
+ "get_current_topological_token", sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@@ -1616,3 +1616,49 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcol="instance_name",
desc="get_name_from_instance_id",
)
+
+ async def get_timeline_gaps(
+ self,
+ room_id: str,
+ from_token: Optional[RoomStreamToken],
+ to_token: RoomStreamToken,
+ ) -> Optional[RoomStreamToken]:
+ """Check if there is a gap, and return a token that marks the position
+ of the gap in the stream.
+ """
+
+ sql = """
+ SELECT instance_name, stream_ordering
+ FROM timeline_gaps
+ WHERE room_id = ? AND ? < stream_ordering AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ """
+
+ rows = await self.db_pool.execute(
+ "get_timeline_gaps",
+ sql,
+ room_id,
+ from_token.stream if from_token else 0,
+ to_token.get_max_stream_pos(),
+ )
+
+ if not rows:
+ return None
+
+ positions = [
+ PersistedEventPosition(instance_name, stream_ordering)
+ for instance_name, stream_ordering in rows
+ ]
+ if from_token:
+ positions = [p for p in positions if p.persisted_after(from_token)]
+
+ positions = [p for p in positions if not p.persisted_after(to_token)]
+
+ if positions:
+ # We return a stream token that ensures the event *at* the position
+ # of the gap is included (as the gap is *before* the persisted
+ # event).
+ last_position = positions[-1]
+ return RoomStreamToken(stream=last_position.stream - 1)
+
+ return None
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 61403a98cf..7deda7790e 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content.
"""
- rows = await self.db_pool.simple_select_list(
- "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_list(
+ "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ ),
)
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
+ for room_id, tag, content in rows:
+ room_tags = tags_by_room.setdefault(room_id, {})
+ room_tags[tag] = db_to_json(content)
return tags_by_room
async def get_all_updated_tags(
@@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A mapping of tags to tag content.
"""
- rows = await self.db_pool.simple_select_list(
- table="room_tags",
- keyvalues={"user_id": user_id, "room_id": room_id},
- retcols=("tag", "content"),
- desc="get_tags_for_room",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="room_tags",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=("tag", "content"),
+ desc="get_tags_for_room",
+ ),
)
- return {row["tag"]: db_to_json(row["content"]) for row in rows}
+ return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c4a6475060..fecddb4144 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+ int,
+ ]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
total number of destinations matching the filter criteria.
@@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns:
- A tuple of a list of mappings from destination to information
+ A tuple of a list of tuples of destination information:
+ * destination
+ * retry_last_ts
+ * retry_interval
+ * failure_ts
+ * last_successful_stream_ordering
and a count of total destinations.
"""
def get_destinations_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[
+ Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+ ],
+ int,
+ ]:
order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS:
@@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ?
"""
txn.execute(sql, args + [limit, start])
- destinations = self.db_pool.cursor_to_dict(txn)
+ destinations = cast(
+ List[
+ Tuple[
+ str, Optional[int], Optional[int], Optional[int], Optional[int]
+ ]
+ ],
+ txn.fetchall(),
+ )
return destinations, count
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 919c66f553..8ab7c42c4a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db_pool.simple_select_list(
- table="ui_auth_sessions_credentials",
- keyvalues={"session_id": session_id},
- retcols=("stage_type", "result"),
- desc="get_completed_ui_auth_stages",
- ):
- results[row["stage_type"]] = db_to_json(row["result"])
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ),
+ )
+ for stage_type, result in rows:
+ results[stage_type] = db_to_json(result)
return results
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
- rows = await self.db_pool.simple_select_list(
- table="ui_auth_sessions_ips",
- keyvalues={"session_id": session_id},
- retcols=("user_agent", "ip"),
- desc="get_user_agents_ips_to_ui_auth_session",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ ),
)
- return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 23eb92c514..a9f5d68b63 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine")
results = cast(
- List[UserProfile],
- await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
- ),
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.db_pool.execute("search_user_dir", sql, *args),
)
limited = len(results) > limit
- return {"limited": limited, "results": results[0:limit]}
+ return {
+ "limited": limited,
+ "results": [
+ {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+ for r in results[0:limit]
+ ],
+ }
def _filter_text_for_index(text: str) -> str:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 6ff533a129..0f9c550b27 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None:
rows = await self.db_pool.execute(
"_background_deduplicate_state",
- None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 09d2a8c5b3..182e429174 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db_pool.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
+ delta_ids = cast(
+ List[Tuple[str, str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ ),
)
return _GetStateGroupDelta(
prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ {
+ (event_type, state_key): event_id
+ for event_type, state_key, event_id in delta_ids
+ },
)
return await self.db_pool.runInteraction(
|