diff --git a/changelog.d/8189.doc b/changelog.d/8189.doc
new file mode 100644
index 0000000000..800ff89dc5
--- /dev/null
+++ b/changelog.d/8189.doc
@@ -0,0 +1 @@
+Explain in more detail what GDPR-erased means when deactivating a user.
diff --git a/changelog.d/8201.misc b/changelog.d/8201.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8201.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8207.misc b/changelog.d/8207.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8207.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8222.misc b/changelog.d/8222.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8222.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8223.bugfix b/changelog.d/8223.bugfix
new file mode 100644
index 0000000000..60655ce3e1
--- /dev/null
+++ b/changelog.d/8223.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where user directory updates could break when unexpected profile data was included in events.
diff --git a/changelog.d/8224.misc b/changelog.d/8224.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8224.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index d6e3194cda..e21c78a9c6 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -214,9 +214,11 @@ Deactivate Account
This API deactivates an account. It removes active access tokens, resets the
password, and deletes third-party IDs (to prevent the user requesting a
-password reset). It can also mark the user as GDPR-erased (stopping their data
-from distributed further, and deleting it entirely if there are no other
-references to it).
+password reset).
+
+It can also mark the user as GDPR-erased. This means messages sent by the
+user will still be visible to anyone who was in the room when these messages
+were sent, but hidden from users joining the room afterwards.
The api is::
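For context, a minimal sketch of how a server admin might invoke this endpoint with erasure enabled. The path and the `erase` flag follow the Synapse admin API documentation; the helper name and token value are illustrative and not part of this change:

```python
import requests

def deactivate_and_erase(base_url: str, admin_token: str, user_id: str) -> None:
    """Deactivate a user and mark them GDPR-erased (sketch, not Synapse code)."""
    resp = requests.post(
        f"{base_url}/_synapse/admin/v1/deactivate/{user_id}",
        headers={"Authorization": f"Bearer {admin_token}"},
        # "erase": True requests GDPR-erasure in addition to deactivation.
        json={"erase": True},
    )
    resp.raise_for_status()

# deactivate_and_erase("https://matrix.example.com", "<admin token>", "@alice:example.com")
```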
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e72a0b9ac0..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,18 +14,20 @@
# limitations under the License.
import logging
import urllib
+from typing import TYPE_CHECKING, Optional
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
+if TYPE_CHECKING:
+ from synapse.appservice import ApplicationService
+
logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
@@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex)
return []
- def get_3pe_protocol(self, service, protocol):
+ async def get_3pe_protocol(
+ self, service: "ApplicationService", protocol: str
+ ) -> Optional[JsonDict]:
if service.url is None:
return {}
- @defer.inlineCallbacks
- def _get():
+ async def _get() -> Optional[JsonDict]:
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol),
)
try:
- info = yield defer.ensureDeferred(self.get_json(uri, {}))
+ info = await self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning(
@@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient):
return None
key = (service.id, protocol)
- return self.protocol_meta_cache.wrap(key, _get)
+ return await self.protocol_meta_cache.wrap(key, _get)
async def push_bulk(self, service, events, txn_id=None):
if service.url is None:
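The change above is representative of the async/await conversions throughout this PR. A minimal before/after sketch of the pattern; `fetch_info` and `client.get_json` are stand-in names, not Synapse's exact API:

```python
from twisted.internet import defer

# Before: a Twisted inlineCallbacks generator; each awaitable is wrapped in
# ensureDeferred and yielded.
@defer.inlineCallbacks
def fetch_info_old(client, uri):
    info = yield defer.ensureDeferred(client.get_json(uri, {}))
    return info

# After: a native coroutine. Callers await it directly; anything that still
# needs a Deferred (e.g. a response cache) can wrap it in ensureDeferred.
async def fetch_info_new(client, uri):
    info = await client.get_json(uri, {})
    return info
```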
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index adb9dc7c42..5c9579394c 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -262,6 +262,9 @@ class BaseProfileHandler(BaseHandler):
Codes.FORBIDDEN,
)
+ if not isinstance(new_displayname, str):
+ raise SynapseError(400, "Invalid displayname")
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -386,6 +389,9 @@ class BaseProfileHandler(BaseHandler):
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
)
+ if not isinstance(new_avatar_url, str):
+            raise SynapseError(400, "Invalid avatar URL")
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
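A small sketch of why the new `isinstance` guards must run before the length checks; the constant and helper here are illustrative, not Synapse's:

```python
MAX_DISPLAYNAME_LEN = 256  # illustrative limit

def validate_displayname(new_displayname):
    if not isinstance(new_displayname, str):
        # Without this guard, len() below raises TypeError on e.g. None or an
        # integer, surfacing as a 500 instead of a clean 400 error.
        raise ValueError("Invalid displayname")
    if len(new_displayname) > MAX_DISPLAYNAME_LEN:
        raise ValueError("Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,))

validate_displayname("Alice")   # passes
# validate_displayname(12345)   # rejected with a clear error rather than a crash
```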
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 521b6d620d..e21f8dbc58 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def _handle_room_publicity_change(
self, room_id, prev_event_id, event_id, typ
):
- """Handle a room having potentially changed from/to world_readable/publically
+ """Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
@@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_name = prev_event.content.get("displayname")
new_name = event.content.get("displayname")
+        # If the new name is of an unexpected form, do not update the directory with it.
+ if not isinstance(new_name, str):
+ new_name = prev_name
prev_avatar = prev_event.content.get("avatar_url")
new_avatar = event.content.get("avatar_url")
+        # If the new avatar is of an unexpected form, do not update the directory with it.
+ if not isinstance(new_avatar, str):
+ new_avatar = prev_avatar
if prev_name != new_name or prev_avatar != new_avatar:
await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
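The fall-back-to-previous pattern above means a malformed field compares equal to the old value, so no directory write is triggered. A self-contained sketch with illustrative data:

```python
def pick_profile_field(prev_content: dict, new_content: dict, key: str):
    prev_value = prev_content.get(key)
    new_value = new_content.get(key)
    # A non-string value falls back to the previous one, as in the diff above.
    if not isinstance(new_value, str):
        new_value = prev_value
    return prev_value, new_value

prev, new = pick_profile_field(
    {"displayname": "Alice"}, {"displayname": {"unexpected": "dict"}}, "displayname"
)
assert prev == new == "Alice"  # no change detected, so no directory update
```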
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 0db900fa0e..67a89cd51a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -433,7 +433,7 @@ class BackgroundUpdater(object):
"background_updates", keyvalues={"update_name": update_name}
)
- def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update
Args:
@@ -441,7 +441,7 @@ class BackgroundUpdater(object):
progress: The progress of the update.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 04042a2c98..4436b1a83d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,9 +16,7 @@
import abc
import logging
-from typing import List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
- def get_account_data_for_user(self, user_id):
+ async def get_account_data_for_user(
+ self, user_id: str
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
- user_id(str): The user to get the account_data for.
+ user_id: The user to get the account_data for.
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A 2-tuple of a dict of global account_data and a dict mapping from
+ room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
- def get_account_data_for_room(self, user_id, room_id):
+ async def get_account_data_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
Returns:
- A deferred dict of the room account_data
+ A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
- def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ async def get_account_data_for_room_and_type(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
- account_data_type (str): The account data type to get.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
+ account_data_type: The account data type to get.
Returns:
- A deferred of the room account_data for that type, or None if
- there isn't any set.
+ The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ async def get_updated_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
- user_id(str): The user to get the account_data for.
- stream_id(int): The point in the stream since which to get updates
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
        Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A 2-tuple of a dict of global account_data and a dict mapping from
+            room_id string to per room account_data dicts.
@@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return defer.succeed(({}, {}))
+ return ({}, {})
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
- def _update_max_stream_id(self, next_id: int):
+ async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+ await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
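Several of the methods above keep their `@cached()` decorator after becoming coroutines. A toy analogue of that combination is sketched below; Synapse's real decorator additionally handles invalidation, cache sizing, and sharing of in-flight calls, which this sketch does not:

```python
import asyncio
from functools import wraps

def cached_async(func):
    """Memoize an async method by its arguments (toy version, no invalidation)."""
    results = {}

    @wraps(func)
    async def wrapper(self, *key):
        if key not in results:
            # The coroutine is awaited once per key; the plain result, not a
            # Deferred/coroutine, is what gets stored and returned thereafter.
            results[key] = await func(self, *key)
        return results[key]

    return wrapper

class Store:
    @cached_async
    async def get_account_data_for_user(self, user_id: str):
        await asyncio.sleep(0)  # stand-in for a database transaction
        return ({}, {})

print(asyncio.run(Store().get_account_data_for_user("@alice:example.com")))
```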
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a29157d979..8bedcdbdff 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -256,8 +256,8 @@ class DeviceWorkerStore(SQLBaseStore):
"""
devices = (
await self.db_pool.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
+ "get_e2e_device_keys_and_signatures_txn",
+ self._get_e2e_device_keys_and_signatures_txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index fb3b1f94de..449d95f31e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -17,6 +17,7 @@
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@@ -33,14 +34,31 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
+@attr.s
+class DeviceKeyLookupResult:
+ """The type returned by _get_e2e_device_keys_and_signatures_txn"""
+
+ display_name = attr.ib(type=Optional[str])
+
+ # the key data from e2e_device_keys_json. Typically includes fields like
+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+ # key) and "signatures" (a signature of the structure by the ed25519 key)
+ key_json = attr.ib(type=Optional[str])
+
+ # cross-signing sigs
+ signatures = attr.ib(type=Optional[Dict], default=None)
+
+
class EndToEndKeyWorkerStore(SQLBaseStore):
- def get_e2e_device_keys_for_federation_query(self, user_id: str):
+ async def get_e2e_device_keys_for_federation_query(
+ self, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
"""Get all devices (with any device keys) for a user
Returns:
- Deferred which resolves to (stream_id, devices)
+ (stream_id, devices)
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
@@ -51,7 +69,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self.get_device_stream_token()
- devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
+ devices = self._get_e2e_device_keys_and_signatures_txn(txn, [(user_id, None)])
if devices:
user_devices = devices[user_id]
@@ -59,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for device_id, device in user_devices.items():
result = {"device_id": device_id}
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
@@ -96,7 +114,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return {}
results = await self.db_pool.runInteraction(
- "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list,
+ "get_e2e_device_keys_and_signatures_txn",
+ self._get_e2e_device_keys_and_signatures_txn,
+ query_list,
)
# Build the result structure, un-jsonify the results, and add the
@@ -105,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
- r = db_to_json(device_info.pop("key_json"))
+ r = db_to_json(device_info.key_json)
r["unsigned"] = {}
- display_name = device_info["device_display_name"]
+ display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
- if "signatures" in device_info:
- for sig_user_id, sigs in device_info["signatures"].items():
+ if device_info.signatures:
+ for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
@@ -120,9 +140,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return rv
@trace
- def _get_e2e_device_keys_txn(
+ def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
@@ -157,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = (
"SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
+ " d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -168,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
- result = {}
- for row in rows:
+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, key_json
+ )
if include_deleted_devices:
for user_id, device_id in deleted_devices:
@@ -205,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices.
continue
- target_device_signatures = target_device_result.setdefault("signatures", {})
+ target_device_signatures = target_device_result.signatures
+ if target_device_signatures is None:
+ target_device_signatures = target_device_result.signatures = {}
+
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
@@ -290,10 +314,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
+ async def count_e2e_one_time_keys(
+ self, user_id: str, device_id: str
+ ) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
+ A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@@ -308,7 +334,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@@ -346,7 +372,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
- def _get_bare_e2e_cross_signing_keys_bulk(
+ async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@@ -354,16 +380,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
- user_ids (list[str]): the users whose keys are being requested
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A mapping from user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict, or
+ their user ID will map to None.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@@ -586,7 +611,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ async def set_e2e_device_keys(
+ self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ ) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@@ -622,12 +649,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
+ async def claim_e2e_one_time_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ """Take a list of one time keys out of the database.
+
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+            A map of user ID -> a map of device ID -> a map of key ID -> JSON bytes.
+ """
@trace
def _claim_e2e_one_time_keys(txn):
@@ -663,11 +699,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- def delete_e2e_keys_by_device(self, user_id, device_id):
+ async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@@ -690,7 +726,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
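The move from raw dict rows to the `DeviceKeyLookupResult` attrs class is the heart of the device-key refactor. A self-contained sketch, using the same `attr.s`/`attr.ib` style as the class above, of why attribute access is safer than `dict.get` here:

```python
from typing import Dict, Optional

import attr

@attr.s
class DeviceKeyLookupResult:
    display_name = attr.ib(type=Optional[str])
    key_json = attr.ib(type=Optional[str])
    signatures = attr.ib(type=Optional[Dict], default=None)

device = DeviceKeyLookupResult(display_name="laptop", key_json='{"keys": {}}')

assert device.display_name == "laptop"
assert device.signatures is None  # not populated until the cross-signing query runs

# A typo is now a loud failure instead of a silent None:
# device.dsplay_name  -> AttributeError
# whereas the old dict rows gave row.get("dsplay_name") == None
```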
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3919ecad69..86557d5512 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
- def get_url_cache(self, url, ts):
+ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
+ return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
@@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ async def update_cached_last_access_time(
+ self,
+ local_media: Iterable[str],
+ remote_media: Iterable[Tuple[str, str]],
+ time_ms: int,
+    ) -> None:
"""Updates the last access time of the given media
Args:
- local_media (iterable[str]): Set of media_ids
- remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ local_media: Set of media_ids
+ remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
- def delete_remote_media(self, media_origin, media_id):
+ async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
- def get_expired_url_cache(self, now_ts):
+ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"delete_url_cache", _delete_url_cache_txn
)
- def get_url_cache_media_before(self, before_ts):
+ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
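`update_cached_last_access_time` above drives `executemany` with a generator of parameter tuples. A runnable sketch of that pattern against stdlib sqlite3; table and column names are illustrative, and the production code goes through Synapse's `DatabasePool`:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_media (media_id TEXT PRIMARY KEY, last_access_ts INTEGER)"
)
conn.executemany(
    "INSERT INTO local_media (media_id, last_access_ts) VALUES (?, 0)",
    [("abc",), ("def",)],
)

time_ms = 1_600_000_000_000
local_media = ["abc", "def"]
# One parameterized UPDATE per media_id, driven by a generator expression,
# mirroring txn.executemany(...) in the diff above.
conn.executemany(
    "UPDATE local_media SET last_access_ts = ? WHERE media_id = ?",
    ((time_ms, media_id) for media_id in local_media),
)
conn.commit()
```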
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7f8d1880e5..f01cf2fd02 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,9 +16,10 @@
import logging
import re
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Set
from synapse.api.errors import SynapseError
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- def _find_highlights_in_postgres(self, search_query, events):
+ async def _find_highlights_in_postgres(
+ self, search_query: str, events: List[EventBase]
+ ) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
- search_query (str)
- events (list): A list of events
+ search_query
+ events: A list of events
Returns:
- deferred : A set of strings.
+ A set of strings.
"""
def f(txn):
@@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.db_pool.runInteraction("_find_highlights", f)
+ return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a9f2e93614..c977db042e 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
import logging
import re
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
- def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+ async def update_profile_in_user_dir(
+        self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
+ ) -> None:
"""
Update or add a user's profile in the user directory.
"""
+        # If the display name or avatar URL are of unexpected types, replace them with None.
+ if not isinstance(display_name, str):
+ display_name = None
+ if not isinstance(avatar_url, str):
+ avatar_url = None
def _update_profile_in_user_dir_txn(txn):
new_entry = self.db_pool.simple_upsert_txn(
@@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def add_users_who_share_private_room(self, room_id, user_id_tuples):
+ async def add_users_who_share_private_room(
+ self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ room_id
+            user_id_tuples: iterable of 2-tuples of user IDs.
"""
def _add_users_who_share_room_txn(txn):
@@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def add_users_in_public_rooms(self, room_id, user_ids):
+ async def add_users_in_public_rooms(
+ self, room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_ids (list[str])
+ room_id
+ user_ids
"""
def _add_users_in_public_rooms_txn(txn):
@@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
- def delete_all_from_user_dir(self):
+ async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
- def remove_from_user_dir(self, user_id):
+ async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
@@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn
)
@@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
- def remove_user_who_share_room(self, user_id, room_id):
+ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
- user_id (str)
- room_id (str)
+ user_id
+ room_id
"""
def _remove_user_who_share_room_txn(txn):
@@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
|