From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/deviceinbox.py | 476 ++++++++++++++++++++++++++ 1 file changed, 476 insertions(+) create mode 100644 synapse/storage/databases/main/deviceinbox.py (limited to 'synapse/storage/databases/main/deviceinbox.py') diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py new file mode 100644 index 0000000000..874ecdf8d2 --- /dev/null +++ b/synapse/storage/databases/main/deviceinbox.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.util.caches.expiringcache import ExpiringCache + +logger = logging.getLogger(__name__) + + +class DeviceInboxWorkerStore(SQLBaseStore): + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute( + sql, (user_id, device_id, last_stream_id, current_stream_id, limit) + ) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(db_to_json(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return messages, stream_pos + + return self.db_pool.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn + ) + + @trace + @defer.inlineCallbacks + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves to the number of messages deleted. + """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + + set_tag("last_deleted_stream_id", last_deleted_stream_id) + + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + log_kv({"message": "No changes in cache since last check"}) + return 0 + + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + return txn.rowcount + + count = yield self.db_pool.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + log_kv( + {"message": "deleted {} messages for device".format(count), "count": count} + ) + + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + return count + + @trace + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int|long): The last position of the device message stream + that the server sent up to. + current_stream_id(int|long): The current position of the device + message stream. + Returns: + Deferred ([dict], int|long): List of messages for the device and where + in the stream the messages got to. + """ + + set_tag("destination", destination) + set_tag("last_stream_id", last_stream_id) + set_tag("current_stream_id", current_stream_id) + set_tag("limit", limit) + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + log_kv({"message": "No new messages in stream"}) + return defer.succeed(([], current_stream_id)) + + if limit <= 0: + # This can happen if we run out of room for EDUs in the transaction. + return defer.succeed(([], last_stream_id)) + + @trace + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (destination, last_stream_id, current_stream_id, limit)) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(db_to_json(row[1])) + if len(messages) < limit: + log_kv({"message": "Set stream position to current position"}) + stream_pos = current_stream_id + return messages, stream_pos + + return self.db_pool.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + @trace + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.db_pool.runInteraction( + "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn + ) + + async def get_all_new_device_messages( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for to device replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_id, last_id + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_id, upper_pos)) + updates = [(row[0], row[1:]) for row in txn] + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY destination" + ) + txn.execute(sql, (last_id, upper_pos)) + updates.extend((row[0], row[1:]) for row in txn) + + # Order by ascending stream ordering + updates.sort() + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + + +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.db_pool.updates.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox + ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + yield self.db_pool.runWithConnection(reindex_txn) + + yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) + + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + + @trace + @defer.inlineCallbacks + def add_messages_to_device_inbox( + self, local_messages_by_user_then_device, remote_messages_by_destination + ): + """Used to send messages from this server. + + Args: + sender_user_id(str): The ID of the user sending these messages. + local_messages_by_user_and_device(dict): + Dictionary of user_id to device_id to message. + remote_messages_by_destination(dict): + Dictionary of destination server_name to the EDU JSON to send. + Returns: + A deferred stream_id that resolves when the messages have been + inserted. + """ + + def add_messages_txn(txn, now_ms, stream_id): + # Add the local messages directly to the local inbox. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + # Add the remote messages to the federation outbox. + # We'll send them to a remote server when we next send a + # federation transaction to that destination. + sql = ( + "INSERT INTO device_federation_outbox" + " (destination, stream_id, queued_ts, messages_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for destination, edu in remote_messages_by_destination.items(): + edu_json = json.dumps(edu) + rows.append((destination, stream_id, now_ms, edu_json)) + txn.executemany(sql, rows) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.db_pool.runInteraction( + "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) + for destination in remote_messages_by_destination.keys(): + self._device_federation_outbox_stream_cache.entity_has_changed( + destination, stream_id + ) + + return self._device_inbox_id_gen.get_current_token() + + @defer.inlineCallbacks + def add_messages_from_remote_to_device_inbox( + self, origin, message_id, local_messages_by_user_then_device + ): + def add_messages_txn(txn, now_ms, stream_id): + # Check if we've already inserted a matching message_id for that + # origin. This can happen if the origin doesn't receive our + # acknowledgement from the first time we received the message. + already_inserted = self.db_pool.simple_select_one_txn( + txn, + table="device_federation_inbox", + keyvalues={"origin": origin, "message_id": message_id}, + retcols=("message_id",), + allow_none=True, + ) + if already_inserted is not None: + return + + # Add an entry for this message_id so that we know we've processed + # it. + self.db_pool.simple_insert_txn( + txn, + table="device_federation_inbox", + values={ + "origin": origin, + "message_id": message_id, + "received_ts": now_ms, + }, + ) + + # Add the messages to the approriate local device inboxes so that + # they'll be sent to the devices when they next sync. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.db_pool.runInteraction( + "add_messages_from_remote_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id) + + return stream_id + + def _add_messages_to_local_device_inbox_txn( + self, txn, stream_id, messages_by_user_then_device + ): + local_by_user_then_device = {} + for user_id, messages_by_device in messages_by_user_then_device.items(): + messages_json_for_user = {} + devices = list(messages_by_device.keys()) + if len(devices) == 1 and devices[0] == "*": + # Handle wildcard device_ids. + sql = "SELECT device_id FROM devices WHERE user_id = ?" + txn.execute(sql, (user_id,)) + message_json = json.dumps(messages_by_device["*"]) + for row in txn: + # Add the message for all devices for this user on this + # server. + device = row[0] + messages_json_for_user[device] = message_json + else: + if not devices: + continue + + clause, args = make_in_list_sql_clause( + txn.database_engine, "device_id", devices + ) + sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause + + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + txn.execute(sql, [user_id] + list(args)) + for row in txn: + # Only insert into the local inbox if the device exists on + # this server + device = row[0] + message_json = json.dumps(messages_by_device[device]) + messages_json_for_user[device] = message_json + + if messages_json_for_user: + local_by_user_then_device[user_id] = messages_json_for_user + + if not local_by_user_then_device: + return + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in local_by_user_then_device.items(): + for device_id, message_json in messages_by_device.items(): + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) -- cgit 1.5.1 From 4dd27e6d1125df83a754b5e0c2c14aaafc0ce837 Mon Sep 17 00:00:00 2001 From: David Vo Date: Fri, 7 Aug 2020 22:02:55 +1000 Subject: Reduce unnecessary whitespace in JSON. (#7372) --- changelog.d/7372.misc | 1 + synapse/http/server.py | 5 +++-- synapse/replication/tcp/commands.py | 5 +++-- synapse/rest/media/v1/preview_url_resource.py | 4 ++-- synapse/storage/databases/main/account_data.py | 7 +++---- synapse/storage/databases/main/deviceinbox.py | 9 ++++----- synapse/storage/databases/main/devices.py | 11 +++++------ synapse/storage/databases/main/e2e_room_keys.py | 11 +++++------ synapse/storage/databases/main/end_to_end_keys.py | 5 +++-- synapse/storage/databases/main/event_push_actions.py | 5 ++--- synapse/storage/databases/main/group_server.py | 17 ++++++++--------- synapse/storage/databases/main/push_rule.py | 9 ++++----- synapse/storage/databases/main/receipts.py | 9 ++++----- synapse/util/__init__.py | 4 ++++ synapse/util/frozenutils.py | 7 +++++-- 15 files changed, 56 insertions(+), 53 deletions(-) create mode 100644 changelog.d/7372.misc (limited to 'synapse/storage/databases/main/deviceinbox.py') diff --git a/changelog.d/7372.misc b/changelog.d/7372.misc new file mode 100644 index 0000000000..67a39f0471 --- /dev/null +++ b/changelog.d/7372.misc @@ -0,0 +1 @@ +Reduce the amount of whitespace in JSON stored and sent in responses. Contributed by David Vo. diff --git a/synapse/http/server.py b/synapse/http/server.py index 94ab29974a..ffe6cfa09e 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -25,7 +25,7 @@ from io import BytesIO from typing import Any, Callable, Dict, Tuple, Union import jinja2 -from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json +from canonicaljson import encode_canonical_json, encode_pretty_printed_json from twisted.internet import defer from twisted.python import failure @@ -46,6 +46,7 @@ from synapse.api.errors import ( from synapse.http.site import SynapseRequest from synapse.logging.context import preserve_fn from synapse.logging.opentracing import trace_servlet +from synapse.util import json_encoder from synapse.util.caches import intern_dict logger = logging.getLogger(__name__) @@ -538,7 +539,7 @@ def respond_with_json( # canonicaljson already encodes to bytes json_bytes = encode_canonical_json(json_object) else: - json_bytes = json.dumps(json_object).encode("utf-8") + json_bytes = json_encoder.encode(json_object).encode("utf-8") return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f33801f883..d853e4447e 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,11 +18,12 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ import abc -import json import logging from typing import Tuple, Type -_json_encoder = json.JSONEncoder() +from canonicaljson import json + +from synapse.util import json_encoder as _json_encoder logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f4768a9e8b..4bb454c36f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -27,7 +27,6 @@ from typing import Dict, Optional from urllib import parse as urlparse import attr -from canonicaljson import json from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -43,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string @@ -355,7 +355,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Calculated OG for %s as %s", url, og) - jsonog = json.dumps(og) + jsonog = json_encoder.encode(og) # store OG in history-aware DB cache await self.store.store_url_cache( diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 2193d8fdc5..cf039e7f7d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,13 +18,12 @@ import abc import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -327,7 +326,7 @@ class AccountDataStore(AccountDataWorkerStore): Returns: A deferred that completes once the account_data has been added. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint @@ -373,7 +372,7 @@ class AccountDataStore(AccountDataWorkerStore): Returns: A deferred that completes once the account_data has been added. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 874ecdf8d2..76ec954f44 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,13 +16,12 @@ import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -354,7 +353,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = json.dumps(edu) + edu_json = json_encoder.encode(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -432,7 +431,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Handle wildcard device_ids. sql = "SELECT device_id FROM devices WHERE user_id = ?" txn.execute(sql, (user_id,)) - message_json = json.dumps(messages_by_device["*"]) + message_json = json_encoder.encode(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -454,7 +453,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = json.dumps(messages_by_device[device]) + message_json = json_encoder.encode(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88a7aadfc6..81e64de126 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional, Set, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import Codes, StoreError @@ -36,6 +34,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.util import json_encoder from synapse.util.caches.descriptors import ( Cache, cached, @@ -397,7 +396,7 @@ class DeviceWorkerStore(SQLBaseStore): values={ "stream_id": stream_id, "from_user_id": from_user_id, - "user_ids": json.dumps(user_ids), + "user_ids": json_encoder.encode(user_ids), }, ) @@ -1032,7 +1031,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, - values={"content": json.dumps(content)}, + values={"content": json_encoder.encode(content)}, # we don't need to lock, because we assume we are the only thread # updating this user's devices. lock=False, @@ -1088,7 +1087,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): { "user_id": user_id, "device_id": content["device_id"], - "content": json.dumps(content), + "content": json_encoder.encode(content), } for content in devices ], @@ -1209,7 +1208,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "device_id": device_id, "sent": False, "ts": now, - "opentracing_context": json.dumps(context) + "opentracing_context": json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", } diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 90152edc3c..c4aaec3993 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): @@ -50,7 +49,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), }, desc="update_e2e_room_key", ) @@ -77,7 +76,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], - "session_data": json.dumps(room_key["session_data"]), + "session_data": json_encoder.encode(room_key["session_data"]), } ) log_kv( @@ -360,7 +359,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "user_id": user_id, "version": new_version, "algorithm": info["algorithm"], - "auth_data": json.dumps(info["auth_data"]), + "auth_data": json_encoder.encode(info["auth_data"]), }, ) @@ -387,7 +386,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues = {} if info is not None and "auth_data" in info: - updatevalues["auth_data"] = json.dumps(info["auth_data"]) + updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) if version_etag is not None: updatevalues["etag"] = version_etag diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 40354b8304..6126376a6f 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -16,7 +16,7 @@ # limitations under the License. from typing import Dict, List, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection from twisted.internet import defer @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -700,7 +701,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): values={ "user_id": user_id, "keytype": key_type, - "keydata": json.dumps(key), + "keydata": json_encoder.encode(key), "stream_id": stream_id, }, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b8cefb4d5e..7c246d3e4c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -17,11 +17,10 @@ import logging from typing import List -from canonicaljson import json - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ def _serialize_action(actions, is_highlight): else: if actions == DEFAULT_NOTIF_ACTION: return "" - return json.dumps(actions) + return json_encoder.encode(actions) def _deserialize_action(actions, is_highlight): diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a98181f445..75ea6d4b2f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -16,12 +16,11 @@ from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -752,7 +751,7 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True @@ -783,7 +782,7 @@ class GroupServerStore(GroupServerWorkerStore): if profile is None: insertion_values["profile"] = "{}" else: - update_values["profile"] = json.dumps(profile) + update_values["profile"] = json_encoder.encode(profile) if is_public is None: insertion_values["is_public"] = True @@ -1007,7 +1006,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) @@ -1131,7 +1130,7 @@ class GroupServerStore(GroupServerWorkerStore): "is_admin": is_admin, "membership": membership, "is_publicised": is_publicised, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, ) @@ -1143,7 +1142,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "type": "membership", - "content": json.dumps( + "content": json_encoder.encode( {"membership": membership, "content": content} ), }, @@ -1171,7 +1170,7 @@ class GroupServerStore(GroupServerWorkerStore): "group_id": group_id, "user_id": user_id, "valid_until_ms": remote_attestation["valid_until_ms"], - "attestation_json": json.dumps(remote_attestation), + "attestation_json": json_encoder.encode(remote_attestation), }, ) else: @@ -1240,7 +1239,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ "valid_until_ms": attestation["valid_until_ms"], - "attestation_json": json.dumps(attestation), + "attestation_json": json_encoder.encode(attestation), }, desc="update_remote_attestion", ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 97cc12931d..264521635f 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from canonicaljson import json - from twisted.internet import defer from synapse.push.baserules import list_with_base_rules @@ -33,6 +31,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.util import json_encoder from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -411,8 +410,8 @@ class PushRuleStore(PushRulesWorkerStore): before=None, after=None, ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) + conditions_json = json_encoder.encode(conditions) + actions_json = json_encoder.encode(actions) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: @@ -681,7 +680,7 @@ class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) + actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): if is_default_rule: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6255977c92..1920a8a152 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -18,13 +18,12 @@ import abc import logging from typing import List, Tuple -from canonicaljson import json - from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -459,7 +458,7 @@ class ReceiptsStore(ReceiptsWorkerStore): values={ "stream_id": stream_id, "event_id": event_id, - "data": json.dumps(data), + "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock @@ -585,7 +584,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, - "event_ids": json.dumps(event_ids), - "data": json.dumps(data), + "event_ids": json_encoder.encode(event_ids), + "data": json_encoder.encode(data), }, ) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index c63256d3bd..b3f76428b6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -17,6 +17,7 @@ import logging import re import attr +from canonicaljson import json from twisted.internet import defer, task @@ -24,6 +25,9 @@ from synapse.logging import context logger = logging.getLogger(__name__) +# Create a custom encoder to reduce the whitespace produced by JSON encoding. +json_encoder = json.JSONEncoder(separators=(",", ":")) + def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index eab78dd256..0e445e01d7 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -63,5 +63,8 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendicts without barfing -frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) +# A JSONEncoder which is capable of encoding frozendicts without barfing. +# Additionally reduce the whitespace produced by JSON encoding. +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, separators=(",", ":"), +) -- cgit 1.5.1 From d68e10f308f89810e8d9ff94219cc68ca83f636d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 09:29:06 -0400 Subject: Convert account data, device inbox, and censor events databases to async/await (#8063) --- changelog.d/8063.misc | 1 + synapse/storage/databases/main/account_data.py | 77 +++++++++++--------- synapse/storage/databases/main/censor_events.py | 11 ++- synapse/storage/databases/main/deviceinbox.py | 94 +++++++++++++------------ tests/handlers/test_typing.py | 3 +- 5 files changed, 99 insertions(+), 87 deletions(-) create mode 100644 changelog.d/8063.misc (limited to 'synapse/storage/databases/main/deviceinbox.py') diff --git a/changelog.d/8063.misc b/changelog.d/8063.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8063.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index cf039e7f7d..82aac2bbf3 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -16,15 +16,16 @@ import abc import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -97,13 +98,15 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, max_entries=5000) - def get_global_account_data_by_type_for_user(self, data_type, user_id): + @cached(num_args=2, max_entries=5000) + async def get_global_account_data_by_type_for_user( + self, data_type: str, user_id: str + ) -> Optional[JsonDict]: """ Returns: - Deferred: A dict + The account data. """ - result = yield self.db_pool.simple_select_one_onecol( + result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -280,9 +283,11 @@ class AccountDataWorkerStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( + @cached(num_args=2, cache_context=True, max_entries=5000) + async def is_ignored_by( + self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext + ) -> bool: + ignored_account_data = await self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, @@ -307,24 +312,27 @@ class AccountDataStore(AccountDataWorkerStore): super(AccountDataStore, self).__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream Returns: - A deferred int. + The maximum stream ID. """ return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ content_json = json_encoder.encode(content) @@ -332,7 +340,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -350,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore): # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -359,18 +367,20 @@ class AccountDataStore(AccountDataWorkerStore): (user_id, room_id, account_data_type), content ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def add_account_data_for_user(self, user_id, account_data_type, content): + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: """Add some account_data to a room for a user. + Args: - user_id(str): The user to add a tag for. - account_data_type(str): The type of account_data to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the account_data has been added. + The maximum stream ID. """ content_json = json_encoder.encode(content) @@ -378,7 +388,7 @@ class AccountDataStore(AccountDataWorkerStore): # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -396,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore): # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. - yield self._update_max_stream_id(next_id) + await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) @@ -404,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore): (account_data_type, user_id) ) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_max_stream_id(self, next_id): + def _update_max_stream_id(self, next_id: int): """Update the max stream_id Args: - next_id(int): The the revision to advance to. + next_id: The the revision to advance to. """ # Note: This is only here for backwards compat to allow admins to diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 1de8249563..f211ddbaf8 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -16,8 +16,6 @@ import logging from typing import TYPE_CHECKING -from twisted.internet import defer - from synapse.events.utils import prune_event_dict from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore @@ -148,17 +146,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase updatevalues={"json": pruned_json}, ) - @defer.inlineCallbacks - def expire_event(self, event_id): + async def expire_event(self, event_id: str) -> None: """Retrieve and expire an event that has expired, and delete its associated expiry timestamp. If the event can't be retrieved, delete its associated timestamp so we don't try to expire it again in the future. Args: - event_id (str): The ID of the event to delete. + event_id: The ID of the event to delete. """ # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) + event = await self.get_event(event_id) def delete_expired_event_txn(txn): # Delete the expiry timestamp associated with this event from the database. @@ -193,7 +190,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase txn, "_get_event_cache", (event.event_id,) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_expired_event", delete_expired_event_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 76ec954f44..1f6e995c4f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -16,8 +16,6 @@ import logging from typing import List, Tuple -from twisted.internet import defer - from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool @@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): + async def get_new_messages_for_device( + self, + user_id: str, + device_id: str, + last_stream_id: int, + current_stream_id: int, + limit: int = 100, + ) -> Tuple[List[dict], int]: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device + user_id: The recipient user_id. + device_id: The recipient device_id. + last_stream_id: The last stream ID checked. + current_stream_id: The current position of the to device message stream. + limit: The maximum number of messages to retrieve. + Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id ) if not has_changed: - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ( @@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @trace - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + async def delete_messages_for_device( + self, user_id: str, device_id: str, up_to_stream_id: int + ) -> int: """ Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. + user_id: The recipient user_id. + device_id: The recipient device_id. + up_to_stream_id: Where to delete messages up to. + Returns: - A deferred that resolves to the number of messages deleted. + The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting @@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.db_pool.runInteraction( + count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): return count @trace - def get_new_device_msgs_for_remote( + async def get_new_device_msgs_for_remote( self, destination, last_stream_id, current_stream_id, limit - ): + ) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. @@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. + A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) @@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) - return defer.succeed(([], current_stream_id)) + return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. - return defer.succeed(([], last_stream_id)) + return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): @@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.db_pool.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ) @trace - @defer.inlineCallbacks - def add_messages_to_device_inbox( - self, local_messages_by_user_then_device, remote_messages_by_destination - ): + async def add_messages_to_device_inbox( + self, + local_messages_by_user_then_device: dict, + remote_messages_by_destination: dict, + ) -> int: """Used to send messages from this server. Args: - sender_user_id(str): The ID of the user sending these messages. - local_messages_by_user_and_device(dict): + local_messages_by_user_and_device: Dictionary of user_id to device_id to message. - remote_messages_by_destination(dict): + remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. + Returns: - A deferred stream_id that resolves when the messages have been - inserted. + The new stream_id. """ def add_messages_txn(txn, now_ms, stream_id): @@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) return self._device_inbox_id_gen.get_current_token() - @defer.inlineCallbacks - def add_messages_from_remote_to_device_inbox( - self, origin, message_id, local_messages_by_user_then_device - ): + async def add_messages_from_remote_to_device_inbox( + self, origin: str, message_id: str, local_messages_by_user_then_device: dict + ) -> int: def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our @@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index b7d0adb10e..64ddd8243d 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import register_federation_servlets @@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None -- cgit 1.5.1 From 2231dffee6788836c86e868dd29574970b13dd18 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 25 Aug 2020 15:10:08 +0100 Subject: Make StreamIdGen `get_next` and `get_next_mult` async (#8161) This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator` will have the same interface, allowing them to be used interchangeably. --- changelog.d/8161.misc | 1 + synapse/storage/databases/main/account_data.py | 4 +-- synapse/storage/databases/main/deviceinbox.py | 4 +-- synapse/storage/databases/main/devices.py | 8 +++-- synapse/storage/databases/main/end_to_end_keys.py | 43 ++++++++++++----------- synapse/storage/databases/main/events.py | 4 +-- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/presence.py | 2 +- synapse/storage/databases/main/push_rule.py | 8 ++--- synapse/storage/databases/main/pusher.py | 4 +-- synapse/storage/databases/main/receipts.py | 3 +- synapse/storage/databases/main/room.py | 6 ++-- synapse/storage/databases/main/tags.py | 4 +-- synapse/storage/util/id_generators.py | 10 +++--- 14 files changed, 54 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8161.misc (limited to 'synapse/storage/databases/main/deviceinbox.py') diff --git a/changelog.d/8161.misc b/changelog.d/8161.misc new file mode 100644 index 0000000000..89ff274de3 --- /dev/null +++ b/changelog.d/8161.misc @@ -0,0 +1 @@ +Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 82aac2bbf3..04042a2c98 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1f6e995c4f..bb85637a95 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9a786e2929..03b45dbc4d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: + with await self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, @@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f93e0d320d..385868bdab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db_pool.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json_encoder.encode(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db_pool.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) + + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b90e6de2d5..6313b41eef 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -153,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 0e3b8739c6..a488e0924b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: + with await self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4e3ec02d14..c9f655dfb7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a585e54812..2fb5b02d7d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 1126fd0751..c388468273 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 19ad1c056f..6821476ee0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: + with await self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 7d3ac47261..b3772be2b2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index ade7abc927..0c34bbf21a 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0bf772d4d1..ddb5c8c60c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -80,7 +80,7 @@ class StreamIdGenerator(object): upwards, -1 to grow downwards. Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -95,10 +95,10 @@ class StreamIdGenerator(object): ) self._unfinished_ids = deque() # type: Deque[int] - def get_next(self): + async def get_next(self): """ Usage: - with stream_id_gen.get_next() as stream_id: + with await stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -117,10 +117,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, n): + async def get_next_mult(self, n): """ Usage: - with stream_id_gen.get_next(n) as stream_ids: + with await stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: -- cgit 1.5.1 From 5c03134d0f8dd157ea1800ce1a4bcddbdb73ddf1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:54:27 -0400 Subject: Convert additional database code to async/await. (#8195) --- changelog.d/8195.misc | 1 + synapse/appservice/__init__.py | 19 ++- synapse/federation/persistence.py | 19 ++- synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/appservice.py | 15 +- synapse/storage/databases/main/deviceinbox.py | 12 +- synapse/storage/databases/main/e2e_room_keys.py | 30 ++-- synapse/storage/databases/main/event_federation.py | 71 ++++---- synapse/storage/databases/main/group_server.py | 187 +++++++++++++-------- synapse/storage/databases/main/keys.py | 24 +-- synapse/storage/databases/main/transactions.py | 39 +++-- 11 files changed, 246 insertions(+), 175 deletions(-) create mode 100644 changelog.d/8195.misc (limited to 'synapse/storage/databases/main/deviceinbox.py') diff --git a/changelog.d/8195.misc b/changelog.d/8195.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8195.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1ffdc1ed95..69a7182ef4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,11 +14,16 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes +from synapse.appservice.api import ApplicationServiceApi from synapse.types import GroupID, get_domain_from_id from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -35,19 +40,19 @@ class AppServiceTransaction(object): self.id = id self.events = events - def send(self, as_api): + async def send(self, as_api: ApplicationServiceApi) -> bool: """Sends this transaction using the provided AS API interface. Args: - as_api(ApplicationServiceApi): The API to use to send. + as_api: The API to use to send. Returns: - An Awaitable which resolves to True if the transaction was sent. + True if the transaction was sent. """ - return as_api.push_bulk( + return await as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id ) - def complete(self, store): + async def complete(self, store: "DataStore") -> None: """Completes this transaction as successful. Marks this transaction ID on the application service and removes the @@ -55,10 +60,8 @@ class AppServiceTransaction(object): Args: store: The database store to operate on. - Returns: - A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn(service=self.service, txn_id=self.id) + await store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 769cd5de28..de1fe7da38 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module. """ import logging +from typing import Optional, Tuple from synapse.federation.units import Transaction from synapse.logging.utils import log_function @@ -36,25 +37,27 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, origin, transaction): - """ Have we already responded to a transaction with the same id and + async def have_responded( + self, origin: str, transaction: Transaction + ) -> Optional[Tuple[int, JsonDict]]: + """Have we already responded to a transaction with the same id and origin? Returns: - Deferred: Results in `None` if we have not previously responded to - this transaction or a 2-tuple of `(int, dict)` representing the - response code and response body. + `None` if we have not previously responded to this transaction or a + 2-tuple of `(int, dict)` representing the response code and response body. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.get_received_txn_response(transaction.transaction_id, origin) + return await self.store.get_received_txn_response(transaction_id, origin) @log_function async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """ Persist how we responded to a transaction. + """Persist how we responded to a transaction. """ transaction_id = transaction.transaction_id # type: ignore if not transaction_id: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 155d087413..16389a0dca 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1879,8 +1879,8 @@ class FederationHandler(BaseHandler): else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 92f56f1602..454c0bc50c 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore( "application_services_state", {"as_id": service.id}, {"state": state} ) - def create_appservice_txn(self, service, events): + async def create_appservice_txn(self, service, events): """Atomically creates a new transaction for this application service with the given list of events. @@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - def complete_appservice_txn(self, txn_id, service): + async def complete_appservice_txn(self, txn_id, service) -> None: """Completes an application service transaction. Args: txn_id(str): The transaction ID being completed. service(ApplicationService): The application service which was sent this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. """ txn_id = int(txn_id) @@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore( {"txn_id": txn_id, "as_id": service.id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) @@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - def set_appservice_last_pos(self, pos): + async def set_appservice_last_pos(self, pos) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index bb85637a95..0044433110 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + async def delete_device_msgs_for_remote( + self, destination: str, up_to_stream_id: int + ) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. + destination: The destination server_name + up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): @@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 82f9d870fd..12cecceec2 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions - def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi(self, user_id, version, room_keys): """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): that we want to query Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No current backup version") return row[0] - def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. Args: @@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: - A deferred dict giving the info metadata for this backup version, with + A dict giving the info metadata for this backup version, with fields including: version(str) algorithm(str) @@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): result["etag"] = 0 return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @trace - def create_e2e_room_keys_version(self, user_id, info): + async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): info(dict): the info about the backup version to be created Returns: - A deferred string for the newly created version ID + The newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def delete_e2e_room_keys_version(self, user_id, version=None): + async def delete_e2e_room_keys_version( + self, user_id: str, version: Optional[str] = None + ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. the version ID of the backup version we're deleting + user_id: the user whose backup version we're deleting + version: Optional. the version ID of the backup version we're deleting If missing, we delete the current backup version info. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present, @@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db_pool.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6e5761c7b7..0b69aa6a94 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas include_given: include the given events in result Returns: - list of event_ids + An awaitable which resolve to a list of event_ids """ return await self.db_pool.runInteraction( "get_auth_chain_ids", @@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas chain. Returns: - Deferred[Set[str]] + The set of the difference in auth chains. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db_pool.runInteraction( + async def get_oldest_events_with_depth_in_room(self, room_id): + return await self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def get_prev_events_for_room(self, room_id: str): + async def get_prev_events_for_room(self, room_id: str) -> List[str]: """ Gets a subset of the current forward extremities in the given room. @@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas events which refer to hundreds of prev_events. Args: - room_id (str): room_id + room_id: room_id Returns: - Deferred[List[str]]: the event ids of the forward extremites + The event ids of the forward extremities. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [row[0] for row in txn] - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + async def get_rooms_with_many_extremities( + self, min_count: int, limit: int, room_id_filter: Iterable[str] + ) -> List[str]: """Get the top rooms with at least N extremities. Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results + min_count: The minimum number of extremities + limit: The maximum number of rooms to return. + room_id_filter: room_ids to exclude from the results Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. + At most `limit` room IDs that have at least `min_count` extremities, + sorted by extremity count. """ def _get_rooms_with_many_extremities_txn(txn): @@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas desc="get_latest_event_ids_in_room", ) - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. + async def get_min_depth(self, room_id: str) -> int: + """For the given room, get the minimum depth we have seen for it. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None - def get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas stream_orderings from that point. Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - deferred, which resolves to a list of event_ids + A list of event_ids """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. @@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) - return self._get_forward_extremeties_for_room(room_id, stream_ordering) + return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` Args: - txn - room_id (str) - event_list (list) - limit (int) + room_id + event_list + limit """ event_ids = await self.db_pool.runInteraction( "get_backfill_events", @@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore): _delete_old_forward_extrem_cache_txn, ) - def clean_room_for_join(self, room_id): - return self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id): + return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 6c60171888..ccfbb2135e 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id: str, include_private: bool = False): + async def get_rooms_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Union[str, bool]]]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. @@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: + A list of dictionaries, each in the form of: { "room_id": "!a_room_id:example.com", # The ID of the room @@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore): for room_id, is_public in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_in_group", _get_rooms_in_group_txn ) - def get_rooms_for_summary_by_category( + async def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, - ): + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request Args: @@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: + A tuple containing: * A list of dictionaries with the keys: * "room_id": str, the room ID @@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore): return rooms, categories - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) @@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request - Returns ([users], [roles]) + Returns: + ([users], [roles]) """ def _get_users_for_summary_txn(txn): @@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore): return users, roles - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) @@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group(self, group_id, user_id): """Get a dict describing the membership of a user in a group. Example if joined: @@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore): "is_privileged": False, } - Returns an empty dict if the user is not join/invite/etc + Returns: + An empty dict if the user is not join/invite/etc """ def _get_users_membership_in_group_txn(txn): @@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore): return {} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_publicised_groups_for_user", ) - def get_attestations_need_renewals(self, valid_until_ms): + async def get_attestations_need_renewals(self, valid_until_ms): """Get all attestations that need to be renewed until givent time """ @@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore): txn.execute(sql, (valid_until_ms,)) return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) @@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_joined_groups", ) - def get_all_groups_for_user(self, user_id, now_token): + async def get_all_groups_for_user(self, user_id, now_token): def _get_all_groups_for_user_txn(txn): sql = """ SELECT group_id, type, membership, u.content @@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore): for row in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore): desc="set_group_join_policy", ) - def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.db_pool.runInteraction( + async def add_room_to_summary( + self, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) room's entry in summary. + + Args: + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore): ) def _add_room_to_summary_txn( - self, txn, group_id, room_id, category_id, order, is_public - ): + self, + txn, + group_id: str, + room_id: str, + category_id: str, + order: int, + is_public: Optional[bool], + ) -> None: """Add (or update) room's entry in summary. Args: - group_id (str) - room_id (str) - category_id (str): If not None then adds the category to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the room at that position, e.g. - an order of 1 will put the room first. Otherwise, the room gets - added to the end. + txn + group_id + room_id + category_id: If not None then adds the category to the end of + the summary if its not already there. + order: If not None inserts the room at that position, e.g. an order + of 1 will put the room first. Otherwise, the room gets added to + the end. + is_public """ room_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_group_role", ) - def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.db_pool.runInteraction( + async def add_user_to_summary( + self, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], + ) -> None: + """Add (or update) user's entry in summary. + + Args: + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public + """ + await self.db_pool.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore): ) def _add_user_to_summary_txn( - self, txn, group_id, user_id, role_id, order, is_public + self, + txn, + group_id: str, + user_id: str, + role_id: str, + order: int, + is_public: Optional[bool], ): """Add (or update) user's entry in summary. Args: - group_id (str) - user_id (str) - role_id (str): If not None then adds the role to the end of - the summary if its not already there. [Optional] - order (int): If not None inserts the user at that position, e.g. - an order of 1 will put the user first. Otherwise, the user gets - added to the end. + txn + group_id + user_id + role_id: If not None then adds the role to the end of the summary if + its not already there. + order: If not None inserts the user at that position, e.g. an order + of 1 will put the user first. Otherwise, the user gets added to + the end. + is_public """ user_in_group = self.db_pool.simple_select_one_onecol_txn( txn, @@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore): desc="add_group_invite", ) - def add_user_to_group( + async def add_user_to_group( self, - group_id, - user_id, - is_admin=False, - is_public=True, - local_attestation=None, - remote_attestation=None, - ): + group_id: str, + user_id: str, + is_admin: bool = False, + is_public: bool = True, + local_attestation: dict = None, + remote_attestation: dict = None, + ) -> None: """Add a user to the group server. Args: - group_id (str) - user_id (str) - is_admin (bool) - is_public (bool) - local_attestation (dict): The attestation the GS created to give - to the remote server. Optional if the user and group are on the - same server - remote_attestation (dict): The attestation given to GS by remote + group_id + user_id + is_admin + is_public + local_attestation: The attestation the GS created to give to the remote server. Optional if the user and group are on the same server + remote_attestation: The attestation given to GS by remote server. + Optional if the user and group are on the same server """ def _add_user_to_group_txn(txn): @@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) + await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) - def remove_user_from_group(self, group_id, user_id): + async def remove_user_from_group(self, group_id: str, user_id: str) -> None: def _remove_user_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) @@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_room_in_group_visibility", ) - def remove_room_from_group(self, group_id, room_id): + async def remove_room_from_group(self, group_id: str, room_id: str) -> None: def _remove_room_from_group_txn(txn): self.db_pool.simple_delete_txn( txn, @@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore): keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) @@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def delete_group(self, group_id): + async def delete_group(self, group_id: str) -> None: """Deletes a group fully from the database. Args: - group_id (str) - - Returns: - Deferred + group_id: The group ID to delete. """ def _delete_group_txn(txn): @@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore): txn, table=table, keyvalues={"group_id": group_id} ) - return self.db_pool.runInteraction("delete_group", _delete_group_txn) + await self.db_pool.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 1c0a049c55..ad43bb05ab 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Iterable, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes @@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore): @cachedList( cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" ) - def get_server_verify_keys(self, server_name_and_key_ids): + async def get_server_verify_keys( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: """ Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): + server_name_and_key_ids: iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown """ keys = {} @@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.db_pool.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_verify_keys", _txn) async def store_server_verify_keys( self, @@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore): desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_server_keys_json", _get_server_keys_json_txn ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 2efcc0dc66..5b31aab700 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Optional, Tuple from canonicaljson import encode_canonical_json @@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore): expiry_ms=5 * 60 * 1000, ) - def get_received_txn_response(self, transaction_id, origin): + async def get_received_txn_response( + self, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). Args: - transaction_id (str) - origin(str) + transaction_id + origin Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) + None if we have not previously responded to this transaction or a + 2-tuple of (int, dict) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore): else: return None - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): + async def set_destination_retry_timings( + self, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. Args: - destination (str) - failure_ts (int|None) - when the server started failing (ms since epoch) - retry_last_ts (int) - time of last retry attempt in unix epoch ms - retry_interval (int) - how long until next retry in ms + destination + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms """ self._destination_retry_cache.pop(destination, None) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore): "cleanup_transactions", self._cleanup_transactions ) - def _cleanup_transactions(self): + async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn ) -- cgit 1.5.1