From 2794b79052f96b8103ce2b710959853313a82e90 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 24 Oct 2019 11:48:46 +0100 Subject: Option to suppress resource exceeded alerting (#6173) The expected use case is to suppress MAU limiting on small instances --- .../test_resource_limits_server_notices.py | 59 +++++++++++++++++++++- tests/utils.py | 1 - 2 files changed, 57 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index cdf89e3383..eb540e34f6 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -17,7 +17,7 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, ServerNoticeMsgType +from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -133,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event - self.assertTrue(self._send_notice.call_count == 2) + self.assertEqual(self._send_notice.call_count, 2) def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): """ @@ -158,6 +158,61 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() + def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): + """ + Test that when server is over MAU limit and alerting is suppressed, then + an alert message is not sent into the room + """ + self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER + ) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + self.assertTrue(self._send_notice.call_count == 0) + + def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): + """ + Test that when a server is disabled, that MAU limit alerting is ignored. + """ + self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED + ) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + # Would be better to check contents, but 2 calls == set blocking event + self.assertEqual(self._send_notice.call_count, 2) + + def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): + """ + When the room is already in a blocked state, test that when alerting + is suppressed that the room is returned to an unblocked state. + """ + self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER + ) + ) + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( + return_value=defer.succeed((True, [])) + ) + + mock_event = Mock( + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + self._send_notice.assert_called_once() + class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): diff --git a/tests/utils.py b/tests/utils.py index 0a64f75d04..8cced4b7e8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -137,7 +137,6 @@ def default_config(name, parse=False): "limit_usage_by_mau": False, "hs_disabled": False, "hs_disabled_message": "", - "hs_disabled_limit_type": "", "max_mau_value": 50, "mau_trial_days": 0, "mau_stats_only": False, -- cgit 1.4.1 From 848cd388d96ec95b2598f1eaaf8967b8f064c08c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 21:13:01 -0400 Subject: delete keys when deleting backups --- synapse/storage/data_stores/main/e2e_room_keys.py | 8 +++ .../delta/56/delete_keys_from_deleted_backups.sql | 25 +++++++ tests/storage/test_e2e_room_keys.py | 76 ++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql create mode 100644 tests/storage/test_e2e_room_keys.py (limited to 'tests') diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index ef88e79293..1cbbae5b63 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -321,9 +321,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _delete_e2e_room_keys_version_txn(txn): if version is None: this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") else: this_version = version + self._simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + return self._simple_update_one_txn( txn, table="e2e_room_keys_versions", diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py new file mode 100644 index 0000000000..ef4e7ce9d6 --- /dev/null +++ b/tests/storage/test_e2e_room_keys.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from tests import unittest, utils + +# sample room_key data for use in the tests +room_key = { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK", +} + + +class E2eRoomKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield utils.setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_room_keys_version_delete(self): + # test that deleting a room key backup deletes the keys + version1 = yield self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + + yield self.store.set_e2e_room_key( + "user_id", version1, "room", "session", room_key + ) + + version2 = yield self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + + yield self.store.set_e2e_room_key( + "user_id", version2, "room", "session", room_key + ) + + # make sure the keys were stored properly + keys = yield self.store.get_e2e_room_keys("user_id", version1) + self.assertEqual(len(keys["rooms"]), 1) + + keys = yield self.store.get_e2e_room_keys("user_id", version2) + self.assertEqual(len(keys["rooms"]), 1) + + # delete version1 + yield self.store.delete_e2e_room_keys_version("user_id", version1) + + # make sure the key from version1 is gone, and the key from version2 is + # still there + keys = yield self.store.get_e2e_room_keys("user_id", version1) + self.assertEqual(len(keys["rooms"]), 0) + + keys = yield self.store.get_e2e_room_keys("user_id", version2) + self.assertEqual(len(keys["rooms"]), 1) -- cgit 1.4.1 From 29a0bc5637e6811220f44ee727370a190b5be1ab Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 24 Oct 2019 21:43:02 -0400 Subject: remove some unnecessary lines --- tests/storage/test_e2e_room_keys.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index ef4e7ce9d6..6658dbda94 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -27,11 +27,6 @@ room_key = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer - self.store = None # type: synapse.storage.DataStore - @defer.inlineCallbacks def setUp(self): hs = yield utils.setup_test_homeserver(self.addCleanup) -- cgit 1.4.1 From 7e7a1461f64888fa71c8f06fb12f29cc3d233179 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 25 Oct 2019 10:57:37 +0100 Subject: Fix tests --- tests/handlers/test_stats.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests') diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index d5c8bd7612..e0075ccd32 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -607,6 +607,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): """ self.hs.config.stats_enabled = False + self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") u1token = self.login("u1", "pass") @@ -618,6 +619,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNone(self._get_current_stats("user", u1)) self.hs.config.stats_enabled = True + self.handler.stats_enabled = True self._perform_background_initial_update() -- cgit 1.4.1 From 4cf3a30a20c64c3939135b00b3eb5b06f273c9f9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 25 Oct 2019 10:42:07 -0400 Subject: switch to using HomeserverTestCase --- tests/storage/test_e2e_room_keys.py | 44 +++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 6658dbda94..9935ac59ce 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -26,46 +26,52 @@ room_key = { } -class E2eRoomKeysHandlerTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield utils.setup_test_homeserver(self.addCleanup) - +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) self.store = hs.get_datastore() + return hs - @defer.inlineCallbacks def test_room_keys_version_delete(self): # test that deleting a room key backup deletes the keys - version1 = yield self.store.create_e2e_room_keys_version( - "user_id", {"algorithm": "rot13", "auth_data": {}} + version1 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) ) - yield self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.get_success( + self.store.set_e2e_room_key( + "user_id", version1, "room", "session", room_key + ) ) - version2 = yield self.store.create_e2e_room_keys_version( - "user_id", {"algorithm": "rot13", "auth_data": {}} + version2 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) ) - yield self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.get_success( + self.store.set_e2e_room_key( + "user_id", version2, "room", "session", room_key + ) ) # make sure the keys were stored properly - keys = yield self.store.get_e2e_room_keys("user_id", version1) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) self.assertEqual(len(keys["rooms"]), 1) - keys = yield self.store.get_e2e_room_keys("user_id", version2) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) self.assertEqual(len(keys["rooms"]), 1) # delete version1 - yield self.store.delete_e2e_room_keys_version("user_id", version1) + self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1)) # make sure the key from version1 is gone, and the key from version2 is # still there - keys = yield self.store.get_e2e_room_keys("user_id", version1) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) self.assertEqual(len(keys["rooms"]), 0) - keys = yield self.store.get_e2e_room_keys("user_id", version2) + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) self.assertEqual(len(keys["rooms"]), 1) -- cgit 1.4.1 From 4697c0de0b0b51b7b5791f3a842b174931261a47 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 25 Oct 2019 10:47:02 -0400 Subject: remove unneeded imports --- tests/storage/test_e2e_room_keys.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 9935ac59ce..d128fde441 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from tests import unittest, utils +from tests import unittest # sample room_key data for use in the tests room_key = { -- cgit 1.4.1 From d0d8a22c13427cce341dbb7ae1d92d2c0ae709c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 28 Oct 2019 13:33:04 +0000 Subject: Quick fix to ensure cache descriptors always return deferreds --- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/util/caches/descriptors.py | 4 ++-- tests/util/caches/test_descriptors.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 22491f3700..2bbdd11941 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -79,7 +79,7 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + rules_for_room = yield self._get_rules_for_room(room_id) rules_by_user = yield rules_for_room.get_rules(event, context) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index e47ab604dd..bc04bfd7d4 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -720,7 +720,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = self._get_joined_hosts_cache(room_id) + cache = yield self._get_joined_hosts_cache(room_id) joined_hosts = yield cache.get_destinations(state_entry) return joined_hosts diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5ac2530a6a..5a8da449b2 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -618,7 +618,7 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 5713870f48..f907903511 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -325,9 +325,9 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(len(obj.fn.cache.cache), 3) r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_not_called() def test_cache_iterable_with_sync_exception(self): -- cgit 1.4.1 From 3f33879be4697666a304a472d090d4def0c671d0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Oct 2019 14:05:32 +0000 Subject: Port federation_server to async/await --- synapse/federation/federation_server.py | 205 ++++++++++++++------------------ tests/handlers/test_typing.py | 3 + 2 files changed, 90 insertions(+), 118 deletions(-) (limited to 'tests') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5fc7c1d67b..15c1fa0a51 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,11 +215,10 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: @@ -237,7 +230,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,58 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: logger.warn("Room version %s not in %s", room_version, supported_versions) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,28 +352,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = yield self._check_sigs_and_hash(room_version, pdu) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -402,48 +384,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - yield self.handler.on_send_leave_request(origin, pdu) + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -462,22 +440,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -503,16 +481,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -536,14 +512,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -553,7 +527,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -586,8 +560,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -640,37 +613,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -680,13 +650,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -799,15 +769,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: logger.warn("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 67f1013051..f360c8e965 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] -- cgit 1.4.1 From 326b3dace77aeb36e516ea9b04ba1baa171bcb47 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 11:35:46 +0000 Subject: Make ObservableDeferred.observe() always return deferred. This makes it easier to use in an async/await world. Also fixes a bug where cache descriptors would occaisonally return a raw value rather than a deferred. --- synapse/util/async_helpers.py | 7 ++----- tests/storage/test__base.py | 2 +- tests/util/caches/test_descriptors.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 7659eaeb42..fd75ba27ad 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -86,11 +86,8 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. """ if not self._result: d = defer.Deferred() @@ -105,7 +102,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index dd49a14524..9b81b536f5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index f907903511..39e360fe24 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.return_value = ["spam", "eggs"] r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = ["chips"] r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() -- cgit 1.4.1