diff options
Diffstat (limited to 'synapse/storage')
19 files changed, 241 insertions, 451 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 2193d8fdc5..cf039e7f7d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,13 +18,12 @@ import abc import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -327,7 +326,7 @@ class AccountDataStore(AccountDataWorkerStore): Returns: A deferred that completes once the account_data has been added. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint @@ -373,7 +372,7 @@ class AccountDataStore(AccountDataWorkerStore): Returns: A deferred that completes once the account_data has been added. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 683afde52b..10de446065 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 712c8d0264..216a5925fc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -82,21 +81,19 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): "devices_last_seen", self._devices_last_seen_update ) - @defer.inlineCallbacks - def _remove_user_ip_nonunique(self, progress, batch_size): + async def _remove_user_ip_nonunique(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.db_pool.runWithConnection(f) - yield self.db_pool.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( "user_ips_drop_nonunique_index" ) return 1 - @defer.inlineCallbacks - def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress, batch_size): # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. @@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) + await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) - yield self.db_pool.updates._end_background_update("user_ips_analyze") + await self.db_pool.updates._end_background_update("user_ips_analyze") return 1 - @defer.inlineCallbacks - def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress, batch_size): # This works function works by scanning the user_ips table in batches # based on `last_seen`. For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they @@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.db_pool.runInteraction( + end_last_seen = await self.db_pool.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -275,15 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.db_pool.runInteraction("user_ips_dups_remove", remove) + await self.db_pool.runInteraction("user_ips_dups_remove", remove) if last: - yield self.db_pool.updates._end_background_update("user_ips_remove_dupes") + await self.db_pool.updates._end_background_update("user_ips_remove_dupes") return batch_size - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update(self, progress, batch_size): """Background update to insert last seen info into devices table """ @@ -346,12 +341,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return len(rows) - updated = yield self.db_pool.runInteraction( + updated = await self.db_pool.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) if not updated: - yield self.db_pool.updates._end_background_update("devices_last_seen") + await self.db_pool.updates._end_background_update("devices_last_seen") return updated @@ -380,8 +375,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks - def insert_client_ip( + async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): if not now: @@ -392,7 +386,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield self.populate_monthly_active_users(user_id) + await self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return @@ -461,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Failed to upsert, log and continue logger.error("Failed to insert client IP %r: %r", entry, e) - @defer.inlineCallbacks - def get_last_client_ip_by_device(self, user_id, device_id): + async def get_last_client_ip_by_device( + self, user_id: str, device_id: Optional[str] + ) -> Dict[Tuple[str, str], dict]: """For each device_id listed, give the user_ip it was last seen on Args: - user_id (str) - device_id (str): If None fetches all devices for the user + user_id: The user to fetch devices for. + device_id: If None fetches all devices for the user Returns: - defer.Deferred: resolves to a dict, where the keys - are (user_id, device_id) tuples. The values are also dicts, with - keys giving the column names + A dictionary mapping a tuple of (user_id, device_id) to dicts, with + keys giving the column names from the devices table. """ keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id - res = yield self.db_pool.simple_select_list( + res = await self.db_pool.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -501,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } return ret - @defer.inlineCallbacks - def get_user_ip_and_agents(self, user): + async def get_user_ip_and_agents(self, user): user_id = user.to_string() results = {} @@ -512,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 874ecdf8d2..76ec954f44 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,13 +16,12 @@ import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -354,7 +353,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) + edu_json = json_encoder.encode(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -432,7 +431,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Handle wildcard device_ids. sql = "SELECT device_id FROM devices WHERE user_id = ?" txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) + message_json = json_encoder.encode(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -454,7 +453,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = json.dumps(messages_by_device[device]) + message_json = json_encoder.encode(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88a7aadfc6..7a5f0bab05 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional, Set, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import Codes, StoreError @@ -36,6 +34,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.util import json_encoder from synapse.util.caches.descriptors import ( Cache, cached, @@ -137,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + cross_signing_key = yield defer.ensureDeferred( + self.get_e2e_cross_signing_key(user, "master") + ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -150,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield self.get_e2e_cross_signing_key( - user, "self_signing" + cross_signing_key = yield defer.ensureDeferred( + self.get_e2e_cross_signing_key(user, "self_signing") ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( @@ -247,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore): destination (str): The host the device updates are intended for from_stream_id (int): The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded + user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: @@ -397,7 +398,7 @@ class DeviceWorkerStore(SQLBaseStore): values={ "stream_id": stream_id, "from_user_id": from_user_id, - "user_ids": json.dumps(user_ids), + "user_ids": json_encoder.encode(user_ids), }, ) @@ -600,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ @@ -1032,7 +1033,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, + values={"content": json_encoder.encode(content)}, # we don't need to lock, because we assume we are the only thread # updating this user's devices. lock=False, @@ -1088,7 +1089,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): { "user_id": user_id, "device_id": content["device_id"], - "content": json.dumps(content), + "content": json_encoder.encode(content), } for content in devices ], @@ -1209,7 +1210,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "device_id": device_id, "sent": False, "ts": now, - "opentracing_context": json.dumps(context) + "opentracing_context": json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", } diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 7819bfcbb3..037e02603c 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,30 +14,29 @@ # limitations under the License. from collections import namedtuple -from typing import Optional - -from twisted.internet import defer +from typing import Iterable, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: + """Gets the room_id and server list for a given room_alias Args: - room_alias (RoomAlias) + room_alias: The alias to translate to an ID. Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found + The room alias mapping or None if no association can be found. """ - room_id = yield self.db_pool.simple_select_one_onecol( + room_id = await self.db_pool.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self.db_pool.simple_select_onecol( + servers = await self.db_pool.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + async def create_room_alias_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Iterable[str], + creator: Optional[str] = None, + ) -> None: """ Creates an association between a room alias and room_id/servers Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. - - Returns: - Deferred + room_alias: The alias to create. + room_id: The target of the alias. + servers: A list of servers through which it may be possible to join the room + creator: Optional user_id of creator. """ def alias_txn(txn): @@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "create_room_alias_association", alias_txn ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - return ret - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.db_pool.runInteraction( + async def delete_room_alias(self, room_alias: RoomAlias) -> str: + room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias): + def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 90152edc3c..2eeb9f97dc 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,18 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json - -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + async def update_e2e_room_key( + self, user_id, version, room_id, session_id, room_key + ): """Replaces the encrypted E2E room key for a given session in a given backup Args: @@ -38,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -50,13 +48,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), }, desc="update_e2e_room_key", ) - @defer.inlineCallbacks - def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys(self, user_id, version, room_keys): """Bulk add room keys to a given backup. Args: @@ -77,7 +74,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), } ) log_kv( @@ -89,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self.db_pool.simple_insert_many( + await self.db_pool.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. @@ -110,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -125,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -243,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -259,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred of the deletion transaction + The deletion transaction """ keyvalues = {"user_id": user_id, "version": int(version)} @@ -268,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -360,7 +357,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "user_id": user_id, "version": new_version, "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), + "auth_data": json_encoder.encode(info["auth_data"]), }, ) @@ -387,7 +384,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues = {} if info is not None and "auth_data" in info: - updatevalues["auth_data"] = json.dumps(info["auth_data"]) + updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) if version_etag is not None: updatevalues["etag"] = version_etag diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 40354b8304..f93e0d320d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,24 +14,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, Iterable, List, Optional, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection -from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter class EndToEndKeyWorkerStore(SQLBaseStore): @trace - @defer.inlineCallbacks - def get_e2e_device_keys( + async def get_e2e_device_keys( self, query_list, include_all_devices=False, include_deleted_devices=False ): """Fetch a list of device keys. @@ -51,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -174,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv(result) return result - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + async def get_e2e_one_time_keys( + self, user_id: str, device_id: str, key_ids: List[str] + ) -> Dict[Tuple[str, str], str]: """Retrieve a number of one-time keys for a user Args: @@ -185,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): retrieve Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key + A map from (algorithm, key_id) to json string for key """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -201,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) return result - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + async def add_e2e_one_time_keys( + self, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: """Insert some new one time keys for a device. Errors if any of the keys already exist. Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ def _add_e2e_one_time_keys(txn): @@ -241,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -268,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - @defer.inlineCallbacks - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + async def get_e2e_cross_signing_key( + self, user_id: str, key_type: str, from_user_id: Optional[str] = None + ) -> Optional[dict]: """Returns a user's cross-signing key. Args: - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' + user_id: the user whose key is being requested + key_type: the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on + from_user_id: if specified, signatures made by this user on the self-signing key will be included in the result Returns: dict of the key data or None if not found """ - res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) user_keys = res.get(user_id) if not user_keys: return None @@ -449,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return keys - @defer.inlineCallbacks - def get_e2e_cross_signing_keys_bulk( - self, user_ids: List[str], from_user_id: str = None - ) -> defer.Deferred: + async def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: Optional[str] = None + ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. Args: - user_ids (list[str]): the users whose keys are being requested - from_user_id (str): if specified, signatures made by this user on + user_ids: the users whose keys are being requested + from_user_id: if specified, signatures made by this user on the self-signing keys will be included in the result Returns: - Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to - key data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A map of user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, + or their user ID will map to None. """ - result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, result, @@ -700,7 +702,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): values={ "user_id": user_id, "keytype": key_type, - "keydata": json.dumps(key), + "keydata": json_encoder.encode(key), "stream_id": stream_id, }, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b8cefb4d5e..7c246d3e4c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -17,11 +17,10 @@ import logging from typing import List -from canonicaljson import json - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ def _serialize_action(actions, is_highlight): else: if actions == DEFAULT_NOTIF_ACTION: return "" - return json.dumps(actions) + return json_encoder.encode(actions) def _deserialize_action(actions, is_highlight): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4d8a24ce4b..1a68bf32cb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -53,47 +53,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - def encode_json(json_object): """ @@ -239,10 +198,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -864,9 +819,8 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), - "count_as_unread": should_count_as_unread(event, context), } - for event, context in events_and_contexts + for event, _ in events_and_contexts ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7b7393f6e..755b7a2a85 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db_pool.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? - AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a98181f445..75ea6d4b2f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -16,12 +16,11 @@ from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -752,7 +751,7 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True @@ -783,7 +782,7 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True @@ -1007,7 +1006,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) @@ -1131,7 +1130,7 @@ class GroupServerStore(GroupServerWorkerStore): "is_admin": is_admin, "membership": membership, "is_publicised": is_publicised, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, ) @@ -1143,7 +1142,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "type": "membership", - "content": json.dumps( + "content": json_encoder.encode( {"membership": membership, "content": content} ), }, @@ -1171,7 +1170,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) else: @@ -1240,7 +1239,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), + "attestation_json": json_encoder.encode(attestation), }, desc="update_remote_attestion", ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 02b01d9619..e71cdd2cb4 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -15,8 +15,6 @@ import logging from typing import List -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached @@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server Args: - user_id (str): user to add/update - - Returns: - Deferred + user_id: user to add/update """ # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # _initialise_reserved_users reasoning that it would be very strange to # include a support user in this context. - is_support = yield self.is_support_user(user_id) + is_support = await self.is_support_user(user_id) if is_support: return - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): return is_insert - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) + is_guest = await self.is_guest(user_id) if is_guest: return - is_trial = yield self.is_trial_user(user_id) + is_trial = await self.is_trial_user(user_id) if is_trial: return - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + last_seen_timestamp = await self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy @@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. if not self._limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) else: - count = yield self.get_monthly_active_count() + count = await self.get_monthly_active_count() if count < self._max_mau_value: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 5fd899326a..19a0211a03 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from canonicaljson import json - from twisted.internet import defer from synapse.push.baserules import list_with_base_rules @@ -33,6 +31,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -419,8 +418,8 @@ class PushRuleStore(PushRulesWorkerStore): before=None, after=None, ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) + conditions_json = json_encoder.encode(conditions) + actions_json = json_encoder.encode(actions) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: @@ -689,7 +688,7 @@ class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) + actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): if is_default_rule: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6255977c92..1920a8a152 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -18,13 +18,12 @@ import abc import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -459,7 +458,7 @@ class ReceiptsStore(ReceiptsWorkerStore): values={ "stream_id": stream_id, "event_id": event_id, - "data": json.dumps(data), + "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -585,7 +584,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), + "event_ids": json_encoder.encode(event_ids), + "data": json_encoder.encode(data), }, ) diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c73..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. -ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2162d0712d..7f8d1880e5 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -16,8 +16,7 @@ import logging import re from collections import namedtuple - -from twisted.internet import defer +from typing import List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -114,8 +113,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) - @defer.inlineCallbacks - def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search(self, progress, batch_size): # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -206,19 +204,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return len(event_search_rows) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search(self, progress, batch_size): """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. """ @@ -255,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.db_pool.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME ) return 1 - @defer.inlineCallbacks - def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -288,12 +284,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) conn.set_session(autocommit=False) - yield self.db_pool.runWithConnection(create_index) + await self.db_pool.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.db_pool.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, @@ -331,12 +327,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return len(rows), True - num_rows, finished = yield self.db_pool.runInteraction( + num_rows, finished = await self.db_pool.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_SEARCH_ORDER_UPDATE_NAME ) @@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(SearchStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def search_msgs(self, room_ids, search_term, keys): + async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. Args: @@ -425,7 +420,7 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self.db_pool.execute( + results = await self.db_pool.execute( "search_msgs", self.db_pool.cursor_to_dict, sql, *args ) @@ -433,7 +428,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -442,11 +437,11 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db_pool.execute( + count_results = await self.db_pool.execute( "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) @@ -462,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore): "count": count, } - @defer.inlineCallbacks - def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): + async def search_rooms( + self, + room_ids: List[str], + search_term: str, + keys: List[str], + limit, + pagination_token: Optional[str] = None, + ) -> List[dict]: """Performs a full text search over events with given keys. Args: - room_id (list): The room_ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports - "content.body", "content.name", "content.topic" - pagination_token (str): A pagination token previously returned + room_ids: The room_ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", + "content.name", "content.topic" + pagination_token: A pagination token previously returned Returns: - list of dicts + Each match as a dictionary. """ clauses = [] @@ -577,7 +578,7 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self.db_pool.execute( + results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args ) @@ -585,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore): # We set redact_behaviour to BLOCK here to prevent redacted events being returned in # search results (which is a data leak) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r["event_id"] for r in results], redact_behaviour=EventRedactBehaviour.BLOCK, ) @@ -594,11 +595,11 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = yield self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" - count_results = yield self.db_pool.execute( + count_results = await self.db_pool.execute( "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args ) diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index dae8e8bd29..be191dd870 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -15,8 +15,6 @@ from unpaddedbase64 import encode_base64 -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList @@ -40,9 +38,8 @@ class SignatureWorkerStore(SQLBaseStore): return self.db_pool.runInteraction("get_event_reference_hashes", f) - @defer.inlineCallbacks - def add_event_hashes(self, event_ids): - hashes = yield self.get_event_reference_hashes(event_ids) + async def add_event_hashes(self, event_ids): + hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} for e_id, h in hashes.items() diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index d73a8e8ab9..af21fe457a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state import StateFilter @@ -59,8 +57,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) - @defer.inlineCallbacks - def _populate_user_directory_createtables(self, progress, batch_size): + async def _populate_user_directory_createtables(self, progress, batch_size): # Get all the rooms that we want to process. def _make_staging_area(txn): @@ -102,45 +99,43 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) - new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.db_pool.runInteraction( + new_pos = await self.get_max_stream_id_in_current_state_deltas() + await self.db_pool.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self.db_pool.simple_insert( + await self.db_pool.simple_insert( TEMP_TABLE + "_position", {"position": new_pos} ) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_createtables" ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_cleanup(self, progress, batch_size): + async def _populate_user_directory_cleanup(self, progress, batch_size): """ Update the user directory stream position, then clean up the old tables. """ - position = yield self.db_pool.simple_select_one_onecol( + position = await self.db_pool.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) - yield self.update_user_directory_stream_pos(position) + await self.update_user_directory_stream_pos(position) def _delete_staging_area(txn): txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_cleanup" ) return 1 - @defer.inlineCallbacks - def _populate_user_directory_process_rooms(self, progress, batch_size): + async def _populate_user_directory_process_rooms(self, progress, batch_size): """ Args: progress (dict) @@ -151,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # If we don't have progress filed, delete everything. if not progress: - yield self.delete_all_from_user_dir() + await self.delete_all_from_user_dir() def _get_next_batch(txn): # Only fetch 250 rooms, so we don't fetch too many at once, even @@ -176,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return rooms_to_work_on - rooms_to_work_on = yield self.db_pool.runInteraction( + rooms_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_rooms" ) return 1 @@ -195,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = yield self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) if is_in_room: - is_public = yield self.is_room_world_readable_or_publicly_joinable( + is_public = await self.is_room_world_readable_or_publicly_joinable( room_id ) - users_with_profile = yield defer.ensureDeferred( - state.get_current_users_in_room(room_id) - ) + users_with_profile = await state.get_current_users_in_room(room_id) user_ids = set(users_with_profile) # Update each user in the user directory. for user_id, profile in users_with_profile.items(): - yield self.update_profile_in_user_dir( + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -223,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): to_insert.add(user_id) if to_insert: - yield self.add_users_in_public_rooms(room_id, to_insert) + await self.add_users_in_public_rooms(room_id, to_insert) to_insert.clear() else: for user_id in user_ids: @@ -243,22 +236,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # If it gets too big, stop and write to the database # to prevent storing too much in RAM. if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: - yield self.add_users_who_share_private_room( + await self.add_users_who_share_private_room( room_id, to_insert ) to_insert.clear() if to_insert: - yield self.add_users_who_share_private_room(room_id, to_insert) + await self.add_users_who_share_private_room(room_id, to_insert) to_insert.clear() # We've finished a room. Delete it from the table. - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( TEMP_TABLE + "_rooms", {"room_id": room_id} ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_rooms", @@ -273,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return processed_event_count - @defer.inlineCallbacks - def _populate_user_directory_process_users(self, progress, batch_size): + async def _populate_user_directory_process_users(self, progress, batch_size): """ If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -305,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return users_to_work_on - users_to_work_on = yield self.db_pool.runInteraction( + users_to_work_on = await self.db_pool.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "populate_user_directory_process_users" ) return 1 @@ -322,18 +314,18 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) for user_id in users_to_work_on: - profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) - yield self.update_profile_in_user_dir( + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) # We've finished processing a user. Delete it from the table. - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( TEMP_TABLE + "_users", {"user_id": user_id} ) # Update the remaining counter. progress["remaining"] -= 1 - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "populate_user_directory", self.db_pool.updates._background_update_progress_txn, "populate_user_directory_process_users", @@ -342,8 +334,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return len(users_to_work_on) - @defer.inlineCallbacks - def is_room_world_readable_or_publicly_joinable(self, room_id): + async def is_room_world_readable_or_publicly_joinable(self, room_id): """Check if the room is either world_readable or publically joinable """ @@ -353,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = yield self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) if hist_vis_ev: if hist_vis_ev.content.get("history_visibility") == "world_readable": return True @@ -590,19 +581,18 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "remove_from_user_dir", _remove_from_user_dir_txn ) - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): + async def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self.db_pool.simple_select_onecol( + user_ids_share_pub = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self.db_pool.simple_select_onecol( + user_ids_share_priv = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -645,8 +635,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): "remove_user_who_share_room", _remove_user_who_share_room_txn ) - @defer.inlineCallbacks - def get_user_dir_rooms_user_is_in(self, user_id): + async def get_user_dir_rooms_user_is_in(self, user_id): """ Returns the rooms that a user is in. @@ -656,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self.db_pool.simple_select_onecol( + rows = await self.db_pool.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self.db_pool.simple_select_onecol( + pub_rows = await self.db_pool.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,32 +663,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - @defer.inlineCallbacks - def get_rooms_in_common_for_users(self, user_id, other_user_id): - """Given two user_ids find out the list of rooms they share. - """ - sql = """ - SELECT room_id FROM ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) AS f1 INNER JOIN ( - SELECT c.room_id FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE type = 'm.room.member' - AND m.membership = 'join' - AND state_key = ? - ) f2 USING (room_id) - """ - - rows = yield self.db_pool.execute( - "get_rooms_in_common_for_users", None, sql, user_id, other_user_id - ) - - return [room_id for room_id, in rows] - def get_user_directory_stream_pos(self): return self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", @@ -708,8 +671,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): desc="get_user_directory_stream_pos", ) - @defer.inlineCallbacks - def search_user_dir(self, user_id, search_term, limit): + async def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -806,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self.db_pool.execute( + results = await self.db_pool.execute( "search_user_dir", self.db_pool.cursor_to_dict, sql, *args ) |