diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 8364ac8bc9..bb86419028 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -348,8 +348,7 @@ class Porter:
backward_chunk = 0
already_ported = 0
else:
- forward_chunk = row["forward_rowid"]
- backward_chunk = row["backward_rowid"]
+ forward_chunk, backward_chunk = row
if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c2109036ec..1027fbfd28 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
import random
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
from synapse.api.errors import (
AuthError,
@@ -23,6 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -306,7 +307,9 @@ class ProfileHandler:
server_name = host
if self._is_mine_server_name(server_name):
- media_info = await self.store.get_local_media(media_id)
+ media_info: Optional[
+ Union[LocalMedia, RemoteMedia]
+ ] = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -322,12 +325,12 @@ class ProfileHandler:
if self.max_avatar_size:
# Ensure avatar does not exceed max allowed avatar size
- if media_info["media_length"] > self.max_avatar_size:
+ if media_info.media_length > self.max_avatar_size:
logger.warning(
"Forbidding avatar change to %s: %d bytes is above the allowed size "
"limit",
mxc,
- media_info["media_length"],
+ media_info.media_length,
)
return False
@@ -335,12 +338,12 @@ class ProfileHandler:
# Ensure the avatar's file type is allowed
if (
self.allowed_avatar_mimetypes
- and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ and media_info.media_type not in self.allowed_avatar_mimetypes
):
logger.warning(
"Forbidding avatar change to %s: mimetype %s not allowed",
mxc,
- media_info["media_type"],
+ media_info.media_type,
)
return False
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 6d680b0795..afd8138caf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -269,7 +269,7 @@ class RoomCreationHandler:
self,
requester: Requester,
old_room_id: str,
- old_room: Dict[str, Any],
+        old_room: Tuple[bool, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
@@ -279,7 +279,7 @@ class RoomCreationHandler:
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
- old_room: a dict containing room information for the room to be replaced,
+ old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
@@ -299,7 +299,7 @@ class RoomCreationHandler:
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
- is_public=old_room["is_public"],
+ is_public=old_room[0],
room_version=new_version,
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 918eb203e2..eddc2af9ba 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
- if old_room is not None and old_room["is_public"]:
+ # If the old room exists and is public.
+ if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 62f2454f5d..389dc5298a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -806,7 +806,7 @@ class SsoHandler:
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
- if media is not None and upload_name == media["upload_name"]:
+ if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..1957426c6a 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+import attr
from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info or media_info.quarantined_by:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
- media_type = media_info["media_type"]
+ media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- url_cache = media_info["url_cache"]
+ media_length = media_info.media_length
+ upload_name = name if name else media_info.upload_name
+ url_cache = media_info.url_cache
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@@ -310,16 +312,20 @@ class MediaRepository:
# We deliberately stream the file outside the lock
if responder:
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
+ upload_name = name if name else media_info.upload_name
await respond_with_responder(
- request, responder, media_type, media_length, upload_name
+ request,
+ responder,
+ media_info.media_type,
+ media_info.media_length,
+ upload_name,
)
else:
respond_404(request)
- async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+ async def get_remote_media_info(
+ self, server_name: str, media_id: str
+ ) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -353,7 +359,7 @@ class MediaRepository:
async def _get_remote_media_impl(
self, server_name: str, media_id: str
- ) -> Tuple[Optional[Responder], dict]:
+ ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -373,15 +379,17 @@ class MediaRepository:
# If we have an entry in the DB, try and look for it
if media_info:
- file_id = media_info["filesystem_id"]
+ file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()
- if not media_info["media_type"]:
- media_info["media_type"] = "application/octet-stream"
+ if not media_info.media_type:
+ media_info = attr.evolve(
+ media_info, media_type="application/octet-stream"
+ )
responder = await self.media_storage.fetch_media(file_info)
if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
if not media_info:
raise e
- file_id = media_info["filesystem_id"]
- if not media_info["media_type"]:
- media_info["media_type"] = "application/octet-stream"
+ file_id = media_info.filesystem_id
+ if not media_info.media_type:
+ media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
- server_name, media_id, file_id, media_info["media_type"]
+ server_name, media_id, file_id, media_info.media_type
)
responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
self,
server_name: str,
media_id: str,
- ) -> dict:
+ ) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -518,7 +526,7 @@ class MediaRepository:
origin=server_name,
media_id=media_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname)
- media_info = {
- "media_type": media_type,
- "media_length": length,
- "upload_name": upload_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- }
-
- return media_info
+ return RemoteMedia(
+ media_origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ media_length=length,
+ upload_name=upload_name,
+ created_ts=time_now_ms,
+ filesystem_id=file_id,
+ last_access_ts=time_now_ms,
+ quarantined_by=None,
+ )
def _get_thumbnail_requirements(
self, media_type: str
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 9b5a3dd5f4..44aac21de6 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -240,15 +240,14 @@ class UrlPreviewer:
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
- and cache_result["expires_ts"] > ts
- and cache_result["response_code"] / 100 == 2
+ and cache_result.expires_ts > ts
+ and cache_result.response_code // 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
- og = cache_result["og"]
- if isinstance(og, str):
- og = og.encode("utf8")
- return og
+ if isinstance(cache_result.og, str):
+ return cache_result.og.encode("utf8")
+ return cache_result.og
# If this URL can be accessed via an allowed oEmbed, use that instead.
url_to_download = url
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 755c59274c..812144a128 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1860,7 +1860,8 @@ class PublicRoomListManager:
if not room:
return False
- return room.get("is_public", False)
+ # The first item is whether the room is public.
+ return room[0]
async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 23a034522c..7e40bea8aa 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- ret = await self.store.get_room(room_id)
- if not ret:
+ room = await self.store.get_room(room_id)
+ if not room:
raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id)
@@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- ret = await self.store.get_room(room_id)
- if not ret:
+ room = await self.store.get_room(room_id)
+ if not room:
raise NotFoundError("Room not found")
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 82944ca711..3534c3c259 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
- return 200, {"visibility": "public" if room["is_public"] else "private"}
+ return 200, {"visibility": "public" if room[0] else "private"}
class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 85b6bdbe72..efda8b4ab4 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -119,7 +119,7 @@ class ThumbnailResource(RestServlet):
if not media_info:
respond_404(request)
return
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
@@ -134,7 +134,7 @@ class ThumbnailResource(RestServlet):
thumbnail_infos,
media_id,
media_id,
- url_cache=bool(media_info["url_cache"]),
+ url_cache=bool(media_info.url_cache),
server_name=None,
)
@@ -152,7 +152,7 @@ class ThumbnailResource(RestServlet):
if not media_info:
respond_404(request)
return
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
respond_404(request)
return
@@ -168,7 +168,7 @@ class ThumbnailResource(RestServlet):
file_info = FileInfo(
server_name=None,
file_id=media_id,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info.url_cache),
thumbnail=info,
)
@@ -188,7 +188,7 @@ class ThumbnailResource(RestServlet):
desired_height,
desired_method,
desired_type,
- url_cache=bool(media_info["url_cache"]),
+ url_cache=bool(media_info.url_cache),
)
if file_path:
@@ -213,7 +213,7 @@ class ThumbnailResource(RestServlet):
server_name, media_id
)
- file_id = media_info["filesystem_id"]
+ file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
@@ -224,7 +224,7 @@ class ThumbnailResource(RestServlet):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
- file_id=media_info["filesystem_id"],
+ file_id=file_id,
thumbnail=info,
)
@@ -280,7 +280,7 @@ class ThumbnailResource(RestServlet):
m_type,
thumbnail_infos,
media_id,
- media_info["filesystem_id"],
+ media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 05775425b7..92b3c77b6d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1660,7 +1660,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
- ) -> Dict[str, Any]:
+ ) -> Tuple[Any, ...]:
...
@overload
@@ -1671,7 +1671,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[Any, ...]]:
...
async def simple_select_one(
@@ -1681,7 +1681,7 @@ class DatabasePool:
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -2190,7 +2190,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
retcols: Collection[str],
allow_none: bool = False,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
if keyvalues:
@@ -2208,7 +2208,7 @@ class DatabasePool:
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- return dict(zip(retcols, row))
+ return row
async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index d7482a1f4e..07f9b65af3 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -747,8 +747,16 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
# Invalidate the cache for any ignored users which were added or removed.
- for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
- self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.ignored_by,
+ [
+ (ignored_user_id,)
+ for ignored_user_id in (
+ previously_ignored_users ^ currently_ignored_users
+ )
+ ],
+ )
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def remove_account_data_for_user(
@@ -824,10 +832,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
# Invalidate the cache for ignored users which were removed.
- for ignored_user_id in previously_ignored_users:
- self._invalidate_cache_and_stream(
- txn, self.ignored_by, (ignored_user_id,)
- )
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.ignored_by,
+ [
+ (ignored_user_id,)
+ for ignored_user_id in previously_ignored_users
+ ],
+ )
# Invalidate for this user the cache tracking ignored users.
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 4d0470ffd9..d7232f566b 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+ def _invalidate_cache_and_stream_bulk(
+ self,
+ txn: LoggingTransaction,
+ cache_func: CachedFunction,
+ key_tuples: Collection[Tuple[Any, ...]],
+ ) -> None:
+ """A bulk version of _invalidate_cache_and_stream.
+
+ Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
+ for each key-tuple over replication.
+
+ This implementation is more efficient than a loop which repeatedly calls the
+ non-bulk version.
+ """
+ if not key_tuples:
+ return
+
+ for keys in key_tuples:
+ txn.call_after(cache_func.invalidate, keys)
+
+ self._send_invalidation_to_replication_bulk(
+ txn, cache_func.__name__, key_tuples
+ )
+
def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None
- # get_next() returns a context manager which is designed to wrap
- # the transaction. However, we want to only get an ID when we want
- # to use it, here, so we need to call __enter__ manually, and have
- # __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
+ def _send_invalidation_to_replication_bulk(
+ self,
+ txn: LoggingTransaction,
+ cache_name: str,
+ key_tuples: Collection[Tuple[Any, ...]],
+ ) -> None:
+ """Announce the invalidation of multiple (but not all) cache entries.
+
+ This is more efficient than repeated calls to the non-bulk version. It should
+        NOT be used to invalidate the entire cache: use
+ `_send_invalidation_to_replication` with keys=None.
+
+ Note that this does *not* invalidate the cache locally.
+
+ Args:
+ txn
+ cache_name
+ key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ assert self._cache_id_gen is not None
+
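+            # Allocate one cache stream ID per key-tuple; they are zipped with the keys below.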
+ stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
+ ts = self._clock.time_msec()
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="cache_invalidation_stream_by_instance",
+ keys=(
+ "stream_id",
+ "instance_name",
+ "cache_func",
+ "keys",
+ "invalidation_ts",
+ ),
+ values=[
+ # We convert key_tuples to a list here because psycopg2 serialises
+                    # lists as pq arrays, but serialises tuples as "composite types".
+ # (We need an array because the `keys` column has type `[]text`.)
+ # See:
+ # https://www.psycopg.org/docs/usage.html#adapt-list
+ # https://www.psycopg.org/docs/usage.html#adapt-tuple
+ (stream_id, self._instance_name, cache_name, list(key_tuple), ts)
+ for stream_id, key_tuple in zip(stream_ids, key_tuples)
+ ],
+ )
+
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 04d12a876c..775abbac79 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
A dict containing the device information, or `None` if the device does not
exist.
"""
- return await self.db_pool.simple_select_one(
- table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_device",
- allow_none=True,
- )
-
- async def get_device_opt(
- self, user_id: str, device_id: str
- ) -> Optional[Dict[str, Any]]:
- """Retrieve a device. Only returns devices that are not marked as
- hidden.
-
- Args:
- user_id: The ID of the user which owns the device
- device_id: The ID of the device to retrieve
- Returns:
- A dict containing the device information, or None if the device does not exist.
- """
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
+ if row is None:
+ return None
+ return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}
async def get_devices_by_user(
self, user_id: str
@@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
retcols=["device_id", "device_data"],
allow_none=True,
)
- return (
- (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
- )
+ return (row[0], json_decoder.decode(row[1])) if row else None
def _store_dehydrated_device_txn(
self,
@@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
`FALSE` have not been converted.
"""
- row = await self.db_pool.simple_select_one(
- table="device_lists_changes_converted_stream_position",
- keyvalues={},
- retcols=["stream_id", "room_id"],
- desc="get_device_change_last_converted_pos",
+ return cast(
+ Tuple[int, str],
+ await self.db_pool.simple_select_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ retcols=["stream_id", "room_id"],
+ desc="get_device_change_last_converted_pos",
+ ),
)
- return row["stream_id"], row["room_id"]
async def set_device_change_last_converted_pos(
self,
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index ad904a26a6..fae23c3407 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
# it isn't there.
raise StoreError(404, "No backup with that version exists")
- result = self.db_pool.simple_select_one_txn(
- txn,
- table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
- retcols=("version", "algorithm", "auth_data", "etag"),
- allow_none=False,
+ row = cast(
+ Tuple[int, str, str, Optional[int]],
+ self.db_pool.simple_select_one_txn(
+ txn,
+ table="e2e_room_keys_versions",
+ keyvalues={
+ "user_id": user_id,
+ "version": this_version,
+ "deleted": 0,
+ },
+ retcols=("version", "algorithm", "auth_data", "etag"),
+ allow_none=False,
+ ),
)
- assert result is not None # see comment on `simple_select_one_txn`
- result["auth_data"] = db_to_json(result["auth_data"])
- result["version"] = str(result["version"])
- if result["etag"] is None:
- result["etag"] = 0
- return result
+ return {
+ "auth_data": db_to_json(row[2]),
+ "version": str(row[0]),
+ "algorithm": row[1],
+ "etag": 0 if row[3] is None else row[3],
+ }
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4f96ac25c7..8cb61eaee3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
-
- if (user_id, device_id) in seen_user_device:
- continue
seen_user_device.add((user_id, device_id))
- self._invalidate_cache_and_stream(
- txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
- )
+
+ self._invalidate_cache_and_stream_bulk(
+ txn, self.get_e2e_unused_fallback_key_types, seen_user_device
+ )
return results
@@ -1268,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if row is None:
continue
- key_id = row["key_id"]
- key_json = row["key_json"]
- used = row["used"]
+ key_id, key_json, used = row
# Mark fallback key as used if not already.
if not used and mark_as_used:
@@ -1376,14 +1372,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
- seen_user_device: Set[Tuple[str, str]] = set()
- for user_id, device_id, _, _, _ in otk_rows:
- if (user_id, device_id) in seen_user_device:
- continue
- seen_user_device.add((user_id, device_id))
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
+ seen_user_device = {
+ (user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
+ }
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.count_e2e_one_time_keys,
+ seen_user_device,
+ )
return otk_rows
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index b8bbd1eccd..98556a0523 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
- if room["has_auth_chain_index"]:
+ # If the room has an auth chain index.
+ if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
@@ -410,7 +411,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
- if room["has_auth_chain_index"]:
+ # If the room has an auth chain index.
+ if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
@@ -1436,24 +1438,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
if event_lookup_result is not None:
+ event_type, depth, stream_ordering = event_lookup_result
logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id,
seed_event_id,
- event_lookup_result["depth"],
- event_lookup_result["stream_ordering"],
- event_lookup_result["type"],
+ depth,
+ stream_ordering,
+ event_type,
)
- if event_lookup_result["depth"]:
- queue.put(
- (
- -event_lookup_result["depth"],
- -event_lookup_result["stream_ordering"],
- seed_event_id,
- event_lookup_result["type"],
- )
- )
+ if depth:
+ queue.put((-depth, -stream_ordering, seed_event_id, event_type))
while not queue.empty() and len(event_id_results) < limit:
try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7c34bde3e5..5207cc0f4e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1934,8 +1934,7 @@ class PersistEventsStore:
if row is None:
return
- redacted_relates_to = row["relates_to_id"]
- rel_type = row["relation_type"]
+ redacted_relates_to, rel_type = row
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 0061805150..9c46c5d7bd 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1222,14 +1222,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
# Iterate the parent IDs and invalidate caches.
- for parent_id in {r[1] for r in relations_to_insert}:
- cache_tuple = (parent_id,)
- self._invalidate_cache_and_stream( # type: ignore[attr-defined]
- txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
- )
- self._invalidate_cache_and_stream( # type: ignore[attr-defined]
- txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
- )
+ cache_tuples = {(r[1],) for r in relations_to_insert}
+ self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
+ txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
+ )
+ self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
+ txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
+ )
if results:
latest_event_id = results[-1][0]
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5bf864c1fb..4e63a16fa2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore):
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- return int(res["topological_ordering"]), int(res["stream_ordering"])
+ return int(res[0]), int(res[1])
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ce88772f9e..c700872fdc 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -107,13 +107,16 @@ class KeyStore(CacheInvalidationWorkerStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- for key_id in verify_keys:
- self._invalidate_cache_and_stream(
- txn, self._get_server_keys_json, ((server_name, key_id),)
- )
- self._invalidate_cache_and_stream(
- txn, self.get_server_key_json_for_remote, (server_name, key_id)
- )
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self._get_server_keys_json,
+ [((server_name, key_id),) for key_id in verify_keys],
+ )
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.get_server_key_json_for_remote,
+ [(server_name, key_id) for key_id in verify_keys],
+ )
await self.db_pool.runInteraction(
"store_server_keys_response", store_server_keys_response_txn
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c8d7c9fd32..3f80a64dc5 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
- Any,
Collection,
- Dict,
Iterable,
List,
Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
media_length: int
upload_name: str
created_ts: int
+ url_cache: Optional[str]
last_access_ts: int
quarantined_by: Optional[str]
safe_from_quarantine: bool
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
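+    """Metadata about a piece of media cached from a remote homeserver."""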
+ media_origin: str
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: Optional[str]
+ filesystem_id: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
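+    """A cached URL preview: response code, expiry time and Open Graph data."""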
+ response_code: int
+ expires_ts: int
+ og: Union[str, bytes]
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
- async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -181,11 +200,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"created_ts",
"quarantined_by",
"url_cache",
+ "last_access_ts",
"safe_from_quarantine",
),
allow_none=True,
desc="get_local_media",
)
+ if row is None:
+ return None
+ return LocalMedia(
+ media_id=media_id,
+ media_type=row[0],
+ media_length=row[1],
+ upload_name=row[2],
+ created_ts=row[3],
+ quarantined_by=row[4],
+ url_cache=row[5],
+ last_access_ts=row[6],
+ safe_from_quarantine=row[7],
+ )
async def get_local_media_by_user_paginate(
self,
@@ -236,6 +269,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
upload_name,
created_ts,
+ url_cache,
last_access_ts,
quarantined_by,
safe_from_quarantine
@@ -257,9 +291,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length=row[2],
upload_name=row[3],
created_ts=row[4],
- last_access_ts=row[5],
- quarantined_by=row[6],
- safe_from_quarantine=bool(row[7]),
+ url_cache=row[5],
+ last_access_ts=row[6],
+ quarantined_by=row[7],
+ safe_from_quarantine=bool(row[8]),
)
for row in txn
]
@@ -390,51 +425,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
- async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+ async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
"""
- def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
# get the most recently cached result (relative to the given ts)
- sql = (
- "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
- " FROM local_media_repository_url_cache"
- " WHERE url = ? AND download_ts <= ?"
- " ORDER BY download_ts DESC LIMIT 1"
- )
+ sql = """
+ SELECT response_code, expires_ts, og
+ FROM local_media_repository_url_cache
+ WHERE url = ? AND download_ts <= ?
+ ORDER BY download_ts DESC LIMIT 1
+ """
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
- sql = (
- "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
- " FROM local_media_repository_url_cache"
- " WHERE url = ? AND download_ts > ?"
- " ORDER BY download_ts ASC LIMIT 1"
- )
+ sql = """
+ SELECT response_code, expires_ts, og
+ FROM local_media_repository_url_cache
+ WHERE url = ? AND download_ts > ?
+ ORDER BY download_ts ASC LIMIT 1
+ """
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
return None
- return dict(
- zip(
- (
- "response_code",
- "etag",
- "expires_ts",
- "og",
- "media_id",
- "download_ts",
- ),
- row,
- )
- )
+ return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@@ -444,7 +467,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
response_code: int,
etag: Optional[str],
expires_ts: int,
- og: Optional[str],
+ og: str,
media_id: str,
download_ts: int,
) -> None:
@@ -510,8 +533,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_cached_remote_media(
self, origin: str, media_id: str
- ) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
+ ) -> Optional[RemoteMedia]:
+ row = await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -520,11 +543,25 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name",
"created_ts",
"filesystem_id",
+ "last_access_ts",
"quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
+ if row is None:
+ return row
+ return RemoteMedia(
+ media_origin=origin,
+ media_id=media_id,
+ media_type=row[0],
+ media_length=row[1],
+ upload_name=row[2],
+ created_ts=row[3],
+ filesystem_id=row[4],
+ last_access_ts=row[5],
+ quarantined_by=row[6],
+ )
async def store_cached_remote_media(
self,
@@ -623,10 +660,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
t_width: int,
t_height: int,
t_type: str,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThumbnailInfo]:
"""Fetch the thumbnail info of given width, height and type."""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
@@ -641,11 +678,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
- "filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)
+ if row is None:
+ return None
+ return ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
@trace
async def store_remote_media_thumbnail(
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 3b444d2d07..0198bb09d2 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -363,10 +363,11 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
# for their user ID.
value_values=[(presence_stream_id,) for _ in user_ids],
)
- for user_id in user_ids:
- self._invalidate_cache_and_stream(
- txn, self._get_full_presence_stream_token_for_user, (user_id,)
- )
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self._get_full_presence_stream_token_for_user,
+ [(user_id,) for user_id in user_ids],
+ )
return await self.db_pool.runInteraction(
"add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 3ba9cc8853..7ed111f632 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -13,7 +13,6 @@
# limitations under the License.
from typing import TYPE_CHECKING, Optional
-from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore):
return 50
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
- try:
- profile = await self.db_pool.simple_select_one(
- table="profiles",
- keyvalues={"full_user_id": user_id.to_string()},
- retcols=("displayname", "avatar_url"),
- desc="get_profileinfo",
- )
- except StoreError as e:
- if e.code == 404:
- # no match
- return ProfileInfo(None, None)
- else:
- raise
-
- return ProfileInfo(
- avatar_url=profile["avatar_url"], display_name=profile["displayname"]
+ profile = await self.db_pool.simple_select_one(
+ table="profiles",
+ keyvalues={"full_user_id": user_id.to_string()},
+ retcols=("displayname", "avatar_url"),
+ desc="get_profileinfo",
+ allow_none=True,
)
+ if profile is None:
+ # no match
+ return ProfileInfo(None, None)
+
+ return ProfileInfo(avatar_url=profile[1], display_name=profile[0])
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 1e11bf2706..c3b3e2baaf 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,19 +295,28 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# so make sure to keep this actually last.
txn.execute("DROP TABLE events_to_purge")
- for event_id, should_delete in event_rows:
- self._invalidate_cache_and_stream(
- txn, self._get_state_group_for_event, (event_id,)
- )
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self._get_state_group_for_event,
+ [(event_id,) for event_id, _ in event_rows],
+ )
- # XXX: This is racy, since have_seen_events could be called between the
- # transaction completing and the invalidation running. On the other hand,
- # that's no different to calling `have_seen_events` just before the
- # event is deleted from the database.
+ # XXX: This is racy, since have_seen_events could be called between the
+ # transaction completing and the invalidation running. On the other hand,
+ # that's no different to calling `have_seen_events` just before the
+ # event is deleted from the database.
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.have_seen_event,
+ [
+ (room_id, event_id)
+ for event_id, should_delete in event_rows
+ if should_delete
+ ],
+ )
+
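+        # Invalidate the per-event caches for the deleted events one at a time.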
+ for event_id, should_delete in event_rows:
if should_delete:
- self._invalidate_cache_and_stream(
- txn, self.have_seen_event, (room_id, event_id)
- )
self.invalidate_get_event_cache_after_txn(txn, event_id)
logger.info("[purge] done")
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 37135d431d..f72a23c584 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore):
"before/after rule not found: %s" % (relative_to_rule,)
)
- base_priority_class = res["priority_class"]
- base_rule_priority = res["priority"]
+ base_priority_class, base_rule_priority = res
if base_priority_class != priority_class:
raise InconsistentRuleException(
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 56e8eb16a8..3484ce9ef9 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- stream_ordering = int(res["stream_ordering"]) if res else None
- rx_ts = res["received_ts"] if res else 0
+ stream_ordering = int(res[0]) if res else None
+ rx_ts = res[1] if res else 0
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 933d76e905..2c3f30e2eb 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet.
"""
- ret_dict = await self.db_pool.simple_select_one(
- table="account_validity",
- keyvalues={"renewal_token": renewal_token},
- retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
- desc="get_user_from_renewal_token",
- )
-
- return (
- ret_dict["user_id"],
- ret_dict["expiration_ts_ms"],
- ret_dict["token_used_ts_ms"],
+ return cast(
+ Tuple[str, int, Optional[int]],
+ await self.db_pool.simple_select_one(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
+ desc="get_user_from_renewal_token",
+ ),
)
async def get_renewal_token_for_user(self, user_id: str) -> str:
@@ -564,16 +561,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
updatevalues={"shadow_banned": shadow_banned},
)
# In order for this to apply immediately, clear the cache for this user.
- tokens = self.db_pool.simple_select_onecol_txn(
+ tokens = self.db_pool.simple_select_list_txn(
txn,
table="access_tokens",
keyvalues={"user_id": user_id},
- retcol="token",
+ retcols=("token",),
+ )
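+            # Each row is a 1-tuple `(token,)`, matching the cache key shape.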
+ self._invalidate_cache_and_stream_bulk(
+ txn, self.get_user_by_access_token, tokens
)
- for token in tokens:
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_access_token, (token,)
- )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
@@ -989,16 +985,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
user id, or None if no user id/threepid mapping exists
"""
- ret = self.db_pool.simple_select_one_txn(
+ return self.db_pool.simple_select_one_onecol_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
- ["user_id"],
+ "user_id",
True,
)
- if ret:
- return ret["user_id"]
- return None
async def user_add_threepid(
self,
@@ -1435,16 +1428,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
if res is None:
return False
+ uses_allowed, pending, completed, expiry_time = res
+
# Check if the token has expired
now = self._clock.time_msec()
- if res["expiry_time"] and res["expiry_time"] < now:
+ if expiry_time and expiry_time < now:
return False
# Check if the token has been used up
- if (
- res["uses_allowed"]
- and res["pending"] + res["completed"] >= res["uses_allowed"]
- ):
+ if uses_allowed and pending + completed >= uses_allowed:
return False
# Otherwise, the token is valid
@@ -1490,8 +1482,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Override type because the return type is only optional if
# allow_none is True, and we don't want mypy throwing errors
# about None not being indexable.
- res = cast(
- Dict[str, Any],
+ pending, completed = cast(
+ Tuple[int, int],
self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
@@ -1506,8 +1498,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"registration_tokens",
keyvalues={"token": token},
updatevalues={
- "completed": res["completed"] + 1,
- "pending": res["pending"] - 1,
+ "completed": completed + 1,
+ "pending": pending - 1,
},
)
@@ -1585,13 +1577,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
A dict, or None if token doesn't exist.
"""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
allow_none=True,
desc="get_one_registration_token",
)
+ if row is None:
+ return None
+ return {
+ "token": row[0],
+ "uses_allowed": row[1],
+ "pending": row[2],
+ "completed": row[3],
+ "expiry_time": row[4],
+ }
async def generate_registration_token(
self, length: int, chars: str
@@ -1714,7 +1715,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return None
# Get all info about the token so it can be sent in the response
- return self.db_pool.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
"registration_tokens",
keyvalues={"token": token},
@@ -1728,6 +1729,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
+ if result is None:
+ return result
+
+ return {
+ "token": result[0],
+ "uses_allowed": result[1],
+ "pending": result[2],
+ "completed": result[3],
+ "expiry_time": result[4],
+ }
+
return await self.db_pool.runInteraction(
"update_registration_token", _update_registration_token_txn
)
@@ -1939,11 +1951,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
- user_id = values["user_id"]
- expiry_ts = values["expiry_ts"]
- used_ts = values["used_ts"]
- auth_provider_id = values["auth_provider_id"]
- auth_provider_session_id = values["auth_provider_session_id"]
+ (
+ user_id,
+ expiry_ts,
+ used_ts,
+ auth_provider_id,
+ auth_provider_session_id,
+ ) = values
# Token was already used
if used_ts is not None:
@@ -2668,10 +2682,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
- for token, _, _ in tokens_and_devices:
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_access_token, (token,)
- )
+ self._invalidate_cache_and_stream_bulk(
+ txn,
+ self.get_user_by_access_token,
+ [(token,) for token, _, _ in tokens_and_devices],
+ )
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
@@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
# reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by
# accident.
- row = {"client_secret": None, "validated_at": None}
+ row = None, None
else:
raise ThreepidValidationError("Unknown session_id")
- retrieved_client_secret = row["client_secret"]
- validated_at = row["validated_at"]
+ retrieved_client_secret, validated_at = row
row = self.db_pool.simple_select_one_txn(
txn,
@@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
raise ThreepidValidationError(
"Validation token not found or has expired"
)
- expires = row["expires"]
- next_link = row["next_link"]
+ expires, next_link = row
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index afb880532e..ef26d5d9d3 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
+ async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
"""Retrieve a room.
Args:
room_id: The ID of the room to retrieve.
Returns:
- A dict containing the room information, or None if the room is unknown.
+ A tuple containing the room information:
+ * True if the room is public
+ * True if the room has an auth chain index
+
+ or None if the room is unknown.
"""
- return await self.db_pool.simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
- desc="get_room",
- allow_none=True,
+ row = cast(
+ Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("is_public", "has_auth_chain_index"),
+ desc="get_room",
+ allow_none=True,
+ ),
)
+ if row is None:
+ return row
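+        # The values may come back as integers (e.g. on SQLite), so coerce to bool.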
+ return bool(row[0]), bool(row[1])
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
@@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
if row:
- return RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- )
+ return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
else:
return None
@@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
join.
"""
- result = await self.db_pool.simple_select_one(
- table="partial_state_rooms",
- keyvalues={"room_id": room_id},
- retcols=("join_event_id", "device_lists_stream_id"),
- desc="get_join_event_id_for_partial_state",
+ return cast(
+ Tuple[str, int],
+ await self.db_pool.simple_select_one(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("join_event_id", "device_lists_stream_id"),
+ desc="get_join_event_id_for_partial_state",
+ ),
)
- return result["join_event_id"], result["device_lists_stream_id"]
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1ed7f2d0ef..60d4a9ef30 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"non-local user %s" % (user_id,),
)
- results_dict = await self.db_pool.simple_select_one(
- "local_current_membership",
- {"room_id": room_id, "user_id": user_id},
- ("membership", "event_id"),
- allow_none=True,
- desc="get_local_current_membership_for_user_in_room",
+ results = cast(
+ Optional[Tuple[str, str]],
+ await self.db_pool.simple_select_one(
+ "local_current_membership",
+ {"room_id": room_id, "user_id": user_id},
+ ("membership", "event_id"),
+ allow_none=True,
+ desc="get_local_current_membership_for_user_in_room",
+ ),
)
- if not results_dict:
+ if not results:
return None, None
- return results_dict.get("membership"), results_dict.get("event_id")
+ return results
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2225f8272d..563c275a2c 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_position_for_event",
)
- return PersistedEventPosition(
- row["instance_name"] or "master", row["stream_ordering"]
- )
+ return PersistedEventPosition(row[1] or "master", row[0])
async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
@@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return RoomStreamToken(
- topological=row["topological_ordering"], stream=row["stream_ordering"]
- )
+ return RoomStreamToken(topological=row[1], stream=row[0])
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = self.db_pool.simple_select_one_txn(
- txn,
- "events",
- keyvalues={"event_id": event_id, "room_id": room_id},
- retcols=["stream_ordering", "topological_ordering"],
+ stream_ordering, topological_ordering = cast(
+ Tuple[int, int],
+ self.db_pool.simple_select_one_txn(
+ txn,
+ "events",
+ keyvalues={"event_id": event_id, "room_id": room_id},
+ retcols=["stream_ordering", "topological_ordering"],
+ ),
)
- # This cannot happen as `allow_none=False`.
- assert results is not None
-
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
- topological=results["topological_ordering"] - 1,
- stream=results["stream_ordering"],
+ topological=topological_ordering - 1, stream=stream_ordering
)
after_token = RoomStreamToken(
- topological=results["topological_ordering"],
- stream=results["stream_ordering"],
+ topological=topological_ordering, stream=stream_ordering
)
rows, start_token = self._paginate_room_events_txn(
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 5555b53575..64543b4d61 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: the task if available, `None` otherwise
"""
- row = await self.db_pool.simple_select_one(
- table="scheduled_tasks",
- keyvalues={"id": id},
- retcols=(
- "id",
- "action",
- "status",
- "timestamp",
- "resource_id",
- "params",
- "result",
- "error",
+ row = cast(
+ Optional[ScheduledTaskRow],
+ await self.db_pool.simple_select_one(
+ table="scheduled_tasks",
+ keyvalues={"id": id},
+ retcols=(
+ "id",
+ "action",
+ "status",
+ "timestamp",
+ "resource_id",
+ "params",
+ "result",
+ "error",
+ ),
+ allow_none=True,
+ desc="get_scheduled_task",
),
- allow_none=True,
- desc="get_scheduled_task",
)
- return (
- TaskSchedulerWorkerStore._convert_row_to_task(
- (
- row["id"],
- row["action"],
- row["status"],
- row["timestamp"],
- row["resource_id"],
- row["params"],
- row["result"],
- row["error"],
- )
- )
- if row
- else None
- )
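+        # `row` already matches the ScheduledTaskRow tuple shape that
+        # _convert_row_to_task expects, so it can be passed straight through.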
+ return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index fecddb4144..2d341affaa 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
- retcols=(
- "transaction_id",
- "origin",
- "ts",
- "response_code",
- "response_json",
- "has_been_referenced",
- ),
+ retcols=("response_code", "response_json"),
allow_none=True,
)
- if result and result["response_code"]:
- return result["response_code"], db_to_json(result["response_json"])
+        # Only return a result if the row exists and the response code is non-zero.
+ if result and result[0]:
+ return result[0], db_to_json(result[1])
else:
return None
@@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
# check we have a row and retry_last_ts is not null or zero
# (retry_last_ts can't be negative)
- if result and result["retry_last_ts"]:
- return DestinationRetryTimings(**result)
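+        # The row is assumed to be (failure_ts, retry_last_ts, retry_interval),
+        # matching the keyword arguments passed below.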
+ if result and result[1]:
+ return DestinationRetryTimings(
+ failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2]
+ )
else:
return None
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 8ab7c42c4a..5b164fed8e 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="get_ui_auth_session",
)
- result["clientdict"] = db_to_json(result["clientdict"])
-
- return UIAuthSessionData(session_id, **result)
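+        # Assuming the query's retcols are ("clientdict", "uri", "method",
+        # "description"), in that order.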
+ return UIAuthSessionData(
+ session_id,
+ clientdict=db_to_json(result[0]),
+ uri=result[1],
+ method=result[2],
+ description=result[3],
+ )
async def mark_ui_auth_stage_complete(
self,
@@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
) -> None:
# Get the current value.
- result = cast(
- Dict[str, Any],
- self.db_pool.simple_select_one_txn(
- txn,
- table="ui_auth_sessions",
- keyvalues={"session_id": session_id},
- retcols=("serverdict",),
- ),
+ result = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcol="serverdict",
)
# Update it and add it back to the database.
- serverdict = db_to_json(result["serverdict"])
+ serverdict = db_to_json(result)
serverdict[key] = value
self.db_pool.simple_update_one_txn(
@@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
- result = await self.db_pool.simple_select_one(
+ result = await self.db_pool.simple_select_one_onecol(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- retcols=("serverdict",),
+ retcol="serverdict",
desc="get_ui_auth_session_data",
)
- serverdict = db_to_json(result["serverdict"])
+ serverdict = db_to_json(result)
return serverdict.get(key, default)
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d4b86ed7a6..93ec06904a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -20,7 +20,6 @@ from typing import (
Collection,
Iterable,
List,
- Mapping,
Optional,
Sequence,
Set,
@@ -868,13 +867,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
- async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
- return await self.db_pool.simple_select_one(
- table="user_directory",
- keyvalues={"user_id": user_id},
- retcols=("display_name", "avatar_url"),
- allow_none=True,
- desc="get_user_in_directory",
+ async def _get_user_in_directory(
+ self, user_id: str
+ ) -> Optional[Tuple[Optional[str], Optional[str]]]:
+ """
+        Fetch a user's information from the user directory.
+
+ Returns:
+ None if the user is unknown, otherwise a tuple of display name and
+ avatar URL (both of which may be None).
+ """
+ return cast(
+ Optional[Tuple[Optional[str], Optional[str]]],
+ await self.db_pool.simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("display_name", "avatar_url"),
+ allow_none=True,
+ desc="get_user_in_directory",
+ ),
)
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 52d708ad17..76a74b8b8f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
next_id = self._load_next_id_txn(txn)
- txn.call_after(self._mark_id_as_finished, next_id)
- txn.call_on_exception(self._mark_id_as_finished, next_id)
+ txn.call_after(self._mark_ids_as_finished, [next_id])
+ txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
@@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return self._return_factor * next_id
- def _mark_id_as_finished(self, next_id: int) -> None:
- """The ID has finished being processed so we should advance the
+    def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
+        """Get the next `n` stream IDs within the given transaction.
+
+        Usage:
+
+            stream_ids = stream_id_gen.get_next_mult_txn(txn, 5)
+            # ... persist events ...
+        """
+
+ # If we have a list of instances that are allowed to write to this
+ # stream, make sure we're in it.
+ if self._writers and self._instance_name not in self._writers:
+ raise Exception("Tried to allocate stream ID on non-writer")
+
+ next_ids = self._load_next_mult_id_txn(txn, n)
+
+ txn.call_after(self._mark_ids_as_finished, next_ids)
+ txn.call_on_exception(self._mark_ids_as_finished, next_ids)
+ txn.call_after(self._notifier.notify_replication)
+
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
+ return [self._return_factor * next_id for next_id in next_ids]
+
+ def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
+ """These IDs have finished being processed so we should advance the
current position if possible.
"""
with self._lock:
- self._unfinished_ids.discard(next_id)
- self._finished_ids.add(next_id)
+ self._unfinished_ids.difference_update(next_ids)
+ self._finished_ids.update(next_ids)
new_cur: Optional[int] = None
@@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
curr, new_cur, self._max_position_of_local_instance
)
- self._add_persisted_position(next_id)
+        # TODO: Can we call this for just the last position, or somehow batch
+        # calls to _add_persisted_position?
+ for next_id in next_ids:
+ self._add_persisted_position(next_id)
def get_current_token(self) -> int:
return self.get_persisted_upto_position()
@@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
- for i in self.stream_ids:
- self.id_gen._mark_id_as_finished(i)
+ self.id_gen._mark_ids_as_finished(self.stream_ids)
self.notifier.notify_replication()
|