diff --git a/changelog.d/16551.misc b/changelog.d/16551.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16551.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 472879c964..c041b67993 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -19,6 +19,8 @@ import logging
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
+import attr
+
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -357,9 +359,9 @@ class IdentityHandler:
# Check to see if a session already exists and that it is not yet
# marked as validated
- if session and session.get("validated_at") is None:
- session_id = session["session_id"]
- last_send_attempt = session["last_send_attempt"]
+ if session and session.validated_at is None:
+ session_id = session.session_id
+ last_send_attempt = session.last_send_attempt
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
@@ -480,7 +482,6 @@ class IdentityHandler:
# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
- validation_session = None
# Try to validate as email
if self.hs.config.email.can_verify_email:
@@ -488,19 +489,18 @@ class IdentityHandler:
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
-
- if validation_session:
- return validation_session
+ if validation_session:
+ return attr.asdict(validation_session)
# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
- validation_session = await self.threepid_from_creds(
+ return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)
- return validation_session
+ return None
async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 78a75bfed6..ab8f7610e9 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
if row:
threepid = {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
+ "medium": row.medium,
+ "address": row.address,
+ "validated_at": row.validated_at,
}
# Valid threepid returned, delete from the db
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 7fd46901f7..72b0f1c5de 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -949,10 +949,7 @@ class MediaRepository:
deleted = 0
- for media in old_media:
- origin = media["media_origin"]
- media_id = media["media_id"]
- file_id = media["filesystem_id"]
+ for origin, media_id, file_id in old_media:
key = (origin, media_id)
logger.info("Deleting: %r", key)
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 8a617af599..a6ce787da1 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
)
- response = {"destinations": destinations, "total": total}
+ response = {
+ "destinations": [
+ {
+ "destination": r[0],
+ "retry_last_ts": r[1],
+ "retry_interval": r[2],
+ "failure_ts": r[3],
+ "last_successful_stream_ordering": r[4],
+ }
+ for r in destinations
+ ],
+ "total": total,
+ }
if (start + limit) < total:
response["next_token"] = str(start + len(destinations))
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 436718c8b2..2d4da38db9 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
- return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
+ result = [
+ {
+ "event_id": ex[0],
+ "state_group": ex[1],
+ "depth": ex[2],
+ "received_ts": ex[3],
+ }
+ for ex in extremities
+ ]
+
+ return HTTPStatus.OK, {"count": len(extremities), "results": result}
class RoomEventContextServlet(RestServlet):
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 19780e4b4c..75d8a37ccf 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
- ret = {"users": users_media, "total": total}
+ ret = {
+ "users": [
+ {
+ "user_id": r[0],
+ "displayname": r[1],
+ "media_count": r[2],
+ "media_length": r[3],
+ }
+ for r in users_media
+ ],
+ "total": total,
+ }
if (start + limit) < total:
ret["next_token"] = start + len(users_media)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 774d5c12f0..b1ece63845 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -35,7 +35,6 @@ from typing import (
Tuple,
Type,
TypeVar,
- Union,
cast,
overload,
)
@@ -1047,43 +1046,20 @@ class DatabasePool:
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- @overload
- async def execute(
- self, desc: str, decoder: Literal[None], query: str, *args: Any
- ) -> List[Tuple[Any, ...]]:
- ...
-
- @overload
- async def execute(
- self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
- ) -> R:
- ...
-
- async def execute(
- self,
- desc: str,
- decoder: Optional[Callable[[Cursor], R]],
- query: str,
- *args: Any,
- ) -> Union[List[Tuple[Any, ...]], R]:
+ async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
Args:
desc: description of the transaction, for logging and metrics
- decoder - The function which can resolve the cursor results to
- something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""
- def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
+ def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
- if decoder:
- return decoder(txn)
- else:
- return txn.fetchall()
+ return txn.fetchall()
return await self.runInteraction(desc, interaction)
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 58177ecec1..711fdddd4e 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
"""
rows = await self.db_pool.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ "_censor_redactions_fetch", sql, before_ts, 100
)
updates = []
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 0b75f6763a..49edbb9e06 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute(
"get_all_devices_changed",
- None,
sql,
from_key,
to_key,
@@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
- "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ "get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f13d776b0d..f70f95eeba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
- None,
sql,
now_stream_id,
user_id,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c5fce1c82b..0061805150 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
- # We need to pass execute a dummy function to handle the txn's result otherwise
- # it tries to call fetchall() on it and fails because there's no result to fetch.
- await self.db_pool.execute(
+ await self.db_pool.runInteraction(
"background_analyze_new_stream_ordering_column",
- lambda txn: None,
- "ANALYZE events(stream_ordering2)",
+ lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index f851bff604..0ba84b1469 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List
+from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
@@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room(
self, room_id: str
- ) -> List[Dict[str, Any]]:
- """Get list of forward extremities for a room."""
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
+ """
+ Get list of forward extremities for a room.
+
+ Returns:
+ A list of tuples of event_id, state_group, depth, and received_ts.
+ """
def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
@@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
"""
txn.execute(sql, (room_id,))
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index f82140b2e8..aeb3db596c 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
+ * The filesystem ID.
+ """
+
+ sql = """
+ SELECT media_origin, media_id, filesystem_id
+ FROM remote_media_cache
+ WHERE last_access_ts < ?
"""
- sql = (
- "SELECT media_origin, media_id, filesystem_id"
- " FROM remote_media_cache"
- " WHERE last_access_ts < ?"
- )
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
- return await self.db_pool.execute(
- "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+ return cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index b0ef7be155..e09ab21593 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidValidationSession:
+ address: str
+ """address of the 3pid"""
+ medium: str
+ """medium of the 3pid"""
+ client_secret: str
+ """a secret provided by the client for this validation session"""
+ session_id: str
+ """ID of the validation session"""
+ last_send_attempt: int
+ """a number serving to dedupe send attempts for this session"""
+ validated_at: Optional[int]
+ """timestamp of when this session was validated if so"""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
@@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering
Returns:
- A dict containing the following:
- * address - address of the 3pid
- * medium - medium of the 3pid
- * client_secret - a secret provided by the client for this validation session
- * session_id - ID of the validation session
- * send_attempt - a number serving to dedupe send attempts for this session
- * validated_at - timestamp of when this session was validated if so
-
- Otherwise None if a validation session is not found
+ A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
@@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn(
txn: LoggingTransaction,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
@@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db_pool.cursor_to_dict(txn)
- if not rows:
+ row = txn.fetchone()
+ if not row:
return None
- return rows[0]
+ return ThreepidValidationSession(
+ address=row[0],
+ session_id=row[1],
+ medium=row[2],
+ client_secret=row[3],
+ last_send_attempt=row[4],
+ validated_at=row[5],
+ )
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a1627dffb7..67e149b586 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, membership, room_id, like_clause
+ "is_host_joined", sql, membership, room_id, like_clause
)
if not rows:
@@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0;
"""
- rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+ rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 1d69c4a5f0..dbde9130c6 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -26,6 +26,7 @@ from typing import (
Set,
Tuple,
Union,
+ cast,
)
import attr
@@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = await self.db_pool.execute(
- "search_msgs", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id).
+ results = cast(
+ List[Tuple[Union[int, float], str, str]],
+ await self.db_pool.execute("search_msgs", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ {"event": event_map[r[2]], "rank": r[0]}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
@@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
- origin_server_ts, stream_ordering, room_id, event_id
+ room_id, event_id, origin_server_ts, stream_ordering
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
@@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int`
args.append(limit)
- results = await self.db_pool.execute(
- "search_rooms", self.db_pool.cursor_to_dict, sql, *args
+ # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
+ results = cast(
+ List[Tuple[Union[int, float], str, str, int, int]],
+ await self.db_pool.execute("search_rooms", sql, *args),
)
- results = list(filter(lambda row: row["room_id"] in room_ids, results))
+ results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
- [r["event_id"] for r in results],
+ [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)
@@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = await self.db_pool.execute(
- "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
+ # List of tuples of (room_id, count).
+ count_results = cast(
+ List[Tuple[str, int]],
+ await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
)
- count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
+ count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
+ "event": event_map[r[2]],
+ "rank": r[0],
+ "pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
- if r["event_id"] in event_map
+ if r[2] in event_map
],
"highlights": highlights,
"count": count,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5b2d0ba870..e96c9b0486 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
@@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by
+
Returns:
- A list of user dicts and an integer representing the total number of
- users that exist given this query
+ A tuple of:
+ A list of tuples of user information (the user ID, displayname,
+ total number of media, total length of media) and
+
+ An integer representing the total number of users that exist
+ given this query
"""
def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []
@@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
+ users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 872df6bda1..2225f8272d 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
row = await self.db_pool.execute(
- "get_current_topological_token", None, sql, room_id, room_id, stream_key
+ "get_current_topological_token", sql, room_id, room_id, stream_key
)
return row[0][0] if row else 0
@@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = await self.db_pool.execute(
"get_timeline_gaps",
- None,
sql,
room_id,
from_token.stream if from_token else 0,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c4a6475060..fecddb4144 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
+ int,
+ ]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
total number of destinations matching the filter criteria.
@@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list
direction: sort ascending or descending
Returns:
- A tuple of a list of mappings from destination to information
+ A tuple of a list of tuples of destination information:
+ * destination
+ * retry_last_ts
+ * retry_interval
+ * failure_ts
+ * last_successful_stream_ordering
and a count of total destinations.
"""
def get_destinations_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[
+ List[
+ Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
+ ],
+ int,
+ ]:
order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS:
@@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ?
"""
txn.execute(sql, args + [limit, start])
- destinations = self.db_pool.cursor_to_dict(txn)
+ destinations = cast(
+ List[
+ Tuple[
+ str, Optional[int], Optional[int], Optional[int], Optional[int]
+ ]
+ ],
+ txn.fetchall(),
+ )
return destinations, count
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 23eb92c514..a9f5d68b63 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine")
results = cast(
- List[UserProfile],
- await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
- ),
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.db_pool.execute("search_user_dir", sql, *args),
)
limited = len(results) > limit
- return {"limited": limited, "results": results[0:limit]}
+ return {
+ "limited": limited,
+ "results": [
+ {"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
+ for r in results[0:limit]
+ ],
+ }
def _filter_text_for_index(text: str) -> str:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 6ff533a129..0f9c550b27 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None:
rows = await self.db_pool.execute(
"_background_deduplicate_state",
- None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 75ae740b43..08214b0013 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_id, stream_ordering = self.get_success(
self.hs.get_datastores().main.db_pool.execute(
"test:get_destination_rooms",
- None,
"""
SELECT event_id, stream_ordering
FROM destination_rooms dr
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 3f5bfa09d4..67ea640902 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
);
"""
self.get_success(
- self.store.db_pool.execute(
- "test_not_null_constraint", lambda _: None, table_sql
+ self.store.db_pool.runInteraction(
+ "test_not_null_constraint", lambda txn: txn.execute(table_sql)
)
)
@@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
# using SQLite.
index_sql = "CREATE INDEX test_index ON test_constraint(a)"
self.get_success(
- self.store.db_pool.execute(
- "test_not_null_constraint", lambda _: None, index_sql
+ self.store.db_pool.runInteraction(
+ "test_not_null_constraint", lambda txn: txn.execute(index_sql)
)
)
@@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
);
"""
self.get_success(
- self.store.db_pool.execute(
- "test_foreign_key_constraint", lambda _: None, base_sql
+ self.store.db_pool.runInteraction(
+ "test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
)
)
self.get_success(
- self.store.db_pool.execute(
- "test_foreign_key_constraint", lambda _: None, table_sql
+ self.store.db_pool.runInteraction(
+ "test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
)
)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 95f99f4130..6afb5403bd 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.store.db_pool.execute(
- "", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
+ "", "SELECT full_user_id from profiles ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index d4637d9d1e..2da6a018e8 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.store.db_pool.execute(
- "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
+ "", "SELECT full_user_id from user_filters ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))
|