From b6b57ecb4e845490fc26a537ff57df8cae1587b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 14:19:48 +0000 Subject: Kill off redundant SynapseRequestFactory (#6619) We already get the Site via the Channel, so there's no need for a dedicated RequestFactory: we can just use the right constructor. --- tests/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/server.py b/tests/server.py index a554dfdd57..1644710aa0 100644 --- a/tests/server.py +++ b/tests/server.py @@ -20,6 +20,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.server import Site from synapse.http.site import SynapseRequest from synapse.util import Clock @@ -42,6 +43,7 @@ class FakeChannel(object): wire). """ + site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @@ -176,9 +178,9 @@ def make_request( content = content.encode("utf8") site = FakeSite() - channel = FakeChannel(reactor) + channel = FakeChannel(site, reactor) - req = request(site, channel) + req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) req.postpath = list(map(unquote, path[1:].split(b"/"))) -- cgit 1.5.1 From 18674eebb1fa5d7445952d7e201afe33bd040523 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 12:28:58 +0000 Subject: Workaround for error when fetching notary's own key (#6620) * Kill off redundant SynapseRequestFactory We already get the Site via the Channel, so there's no need for a dedicated RequestFactory: we can just use the right constructor. * Workaround for error when fetching notary's own key As a notary server, when we return our own keys, include all of our signing keys in verify_keys. This is a workaround for #6596. --- changelog.d/6620.misc | 1 + synapse/rest/key/v2/remote_key_resource.py | 30 ++++-- tests/rest/key/v2/test_remote_key_resource.py | 130 ++++++++++++++++++++++++++ tests/unittest.py | 11 ++- 4 files changed, 163 insertions(+), 9 deletions(-) create mode 100644 changelog.d/6620.misc create mode 100644 tests/rest/key/v2/test_remote_key_resource.py (limited to 'tests') diff --git a/changelog.d/6620.misc b/changelog.d/6620.misc new file mode 100644 index 0000000000..8bfb78fb20 --- /dev/null +++ b/changelog.d/6620.misc @@ -0,0 +1 @@ +Add a workaround for synapse raising exceptions when fetching the notary's own key from the notary. diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e7fc3f0431..bf5e0eb844 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,6 +15,7 @@ import logging from canonicaljson import encode_canonical_json, json +from signedjson.key import encode_verify_key_base64 from signedjson.sign import sign_json from twisted.internet import defer @@ -216,15 +217,28 @@ class RemoteKey(DirectServeResource): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) yield self.query_keys(request, query, query_remote_on_cache_miss=False) - else: - signed_keys = [] - for key_json in json_results: - key_json = json.loads(key_json) + return + + signed_keys = [] + for key_json in json_results: + key_json = json.loads(key_json) + + # backwards-compatibility hack for #6596: if the requested key belongs + # to us, make sure that all of the signing keys appear in the + # "verify_keys" section. + if key_json["server_name"] == self.config.server_name: + verify_keys = key_json["verify_keys"] for signing_key in self.config.key_server_signing_keys: - key_json = sign_json(key_json, self.config.server_name, signing_key) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + verify_keys[key_id] = { + "key": encode_verify_key_base64(signing_key.verify_key) + } + + for signing_key in self.config.key_server_signing_keys: + key_json = sign_json(key_json, self.config.server_name, signing_key) - signed_keys.append(key_json) + signed_keys.append(key_json) - results = {"server_keys": signed_keys} + results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json_bytes(request, 200, encode_canonical_json(results)) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py new file mode 100644 index 0000000000..d8246b4e78 --- /dev/null +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import urllib.parse +from io import BytesIO + +from mock import Mock + +import signedjson.key +from nacl.signing import SigningKey +from signedjson.sign import sign_json + +from twisted.web.resource import NoResource + +from synapse.http.site import SynapseRequest +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.util.httpresourcetree import create_resource_tree + +from tests import unittest +from tests.server import FakeChannel, wait_until_result + + +class RemoteKeyResourceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + return self.setup_test_homeserver(http_client=self.http_client) + + def create_test_json_resource(self): + return create_resource_tree( + {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + ) + + def expect_outgoing_key_request( + self, server_name: str, signing_key: SigningKey + ) -> None: + """ + Tell the mock http client to expect an outgoing GET request for the given key + """ + + def get_json(destination, path, ignore_backoff=False, **kwargs): + self.assertTrue(ignore_backoff) + self.assertEqual(destination, server_name) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + self.assertEqual( + path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) + ) + + response = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": 200 * 1000, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64( + signing_key.verify_key + ) + } + }, + } + sign_json(response, server_name, signing_key) + return response + + self.http_client.get_json.side_effect = get_json + + def make_notary_request(self, server_name: str, key_id: str) -> dict: + """Send a GET request to the test server requesting the given key. + + Checks that the response is a 200 and returns the decoded json body. + """ + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/key/v2/query/%s/%s" + % (server_name.encode("utf-8"), key_id.encode("utf-8")), + b"1.1", + ) + wait_until_result(self.reactor, req) + self.assertEqual(channel.code, 200) + resp = channel.json_body + return resp + + def test_get_key(self): + """Fetch a remote key""" + SERVER_NAME = "remote.server" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(SERVER_NAME, testkey) + + resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertEqual(len(keys[0]["verify_keys"]), 1) + + # it should be signed by both the origin server and the notary + self.assertIn(SERVER_NAME, keys[0]["signatures"]) + self.assertIn(self.hs.hostname, keys[0]["signatures"]) + + def test_get_own_key(self): + """Fetch our own key""" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(self.hs.hostname, testkey) + + resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + # it should be signed by both itself, and the notary signing key + sigs = keys[0]["signatures"] + self.assertEqual(len(sigs), 1) + self.assertIn(self.hs.hostname, sigs) + oursigs = sigs[self.hs.hostname] + self.assertEqual(len(oursigs), 2) + + # and both keys should be present in the verify_keys section + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertIn("ed25519:a_lPym", keys[0]["verify_keys"]) diff --git a/tests/unittest.py b/tests/unittest.py index b30b7d1718..cbda237278 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -36,7 +36,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester @@ -210,6 +210,15 @@ class HomeserverTestCase(TestCase): # Register the resources self.resource = self.create_test_json_resource() + # create a site to wrap the resource. + self.site = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="test", + config={}, + resource=self.resource, + server_version_string="1", + ) + from tests.rest.client.v1.utils import RestHelper self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) -- cgit 1.5.1 From 4b36b482e0cc1a63db27534c4ea5d9608cdb6a79 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 12:33:56 +0000 Subject: Fix exception when fetching notary server's old keys (#6625) Lift the restriction that *all* the keys used for signing v2 key responses be present in verify_keys. Fixes #6596. --- changelog.d/6625.bugfix | 1 + synapse/crypto/keyring.py | 13 ++-- tests/crypto/test_keyring.py | 139 +++++++++++++++++++++++++++++-------------- 3 files changed, 103 insertions(+), 50 deletions(-) create mode 100644 changelog.d/6625.bugfix (limited to 'tests') diff --git a/changelog.d/6625.bugfix b/changelog.d/6625.bugfix new file mode 100644 index 0000000000..a8dc5587dc --- /dev/null +++ b/changelog.d/6625.bugfix @@ -0,0 +1 @@ +Fix exception when fetching the `matrix.org:ed25519:auto` key. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 7cfad192e8..6fe5a6a26a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -511,17 +511,18 @@ class BaseV2KeyFetcher(object): server_name = response_json["server_name"] verified = False for key_id in response_json["signatures"].get(server_name, {}): - # each of the keys used for the signature must be present in the response - # json. key = verify_keys.get(key_id) if not key: - raise KeyLookupError( - "Key response is signed by key id %s:%s but that key is not " - "present in the response" % (server_name, key_id) - ) + # the key may not be present in verify_keys if: + # * we got the key from the notary server, and: + # * the key belongs to the notary server, and: + # * the notary server is using a different key to sign notary + # responses. + continue verify_signed_json(response_json, server_name, key.verify_key) verified = True + break if not verified: raise KeyLookupError( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 8efd39c7f7..34d5895f18 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,6 +19,7 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer @@ -412,34 +413,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): handlers=None, http_client=self.http_client, config=config ) - def test_get_keys_from_perspectives(self): - # arbitrarily advance the clock a bit - self.reactor.advance(100) - - fetcher = PerspectivesKeyFetcher(self.hs) - - SERVER_NAME = "server2" - testkey = signedjson.key.generate_signing_key("ver1") - testverifykey = signedjson.key.get_verify_key(testkey) - testverifykey_id = "ed25519:ver1" - VALID_UNTIL_TS = 200 * 1000 + def build_perspectives_response( + self, server_name: str, signing_key: SigningKey, valid_until_ts: int, + ) -> dict: + """ + Build a valid perspectives server response to a request for the given key + """ + verify_key = signedjson.key.get_verify_key(signing_key) + verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version) - # valid response response = { - "server_name": SERVER_NAME, + "server_name": server_name, "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, + "valid_until_ts": valid_until_ts, "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) + verifykey_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) } }, } - # the response must be signed by both the origin server and the perspectives # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) + signedjson.sign.sign_json(response, server_name, signing_key) self.mock_perspective_server.sign_response(response) + return response + + def expect_outgoing_key_query( + self, expected_server_name: str, expected_key_id: str, response: dict + ) -> None: + """ + Tell the mock http client to expect a perspectives-server key query + """ def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) @@ -447,11 +451,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the request is for the expected key q = data["server_keys"] - self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) + self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id]) return {"server_keys": [response]} self.http_client.post_json.side_effect = post_json + def test_get_keys_from_perspectives(self): + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = "server2" + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS, + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) + + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertIn(SERVER_NAME, keys) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") + + # check that the perspectives store is correctly updated + lookup_triplet = (SERVER_NAME, testverifykey_id, None) + key_json = self.get_success( + self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + ) + res = key_json[lookup_triplet] + self.assertEqual(len(res), 1) + res = res[0] + self.assertEqual(res["key_id"], testverifykey_id) + self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) + self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) + self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + + self.assertEqual( + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) + ) + + def test_get_perspectives_own_key(self): + """Check that we can get the perspectives server's own keys + + This is slightly complicated by the fact that the perspectives server may + use different keys for signing notary responses. + """ + + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = self.mock_perspective_server.server_name + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} keys = self.get_success(fetcher.get_keys(keys_to_fetch)) self.assertIn(SERVER_NAME, keys) @@ -490,35 +562,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): VALID_UNTIL_TS = 200 * 1000 def build_response(): - # valid response - response = { - "server_name": SERVER_NAME, - "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, - "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) - } - }, - } - - # the response must be signed by both the origin server and the perspectives - # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) - self.mock_perspective_server.sign_response(response) - return response + return self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) keys_to_fetch = {SERVER_NAME: {"key1": 0}} - - def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") - return {"server_keys": [response]} - - self.http_client.post_json.side_effect = post_json - + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) return self.get_success(fetcher.get_keys(keys_to_fetch)) # start with a valid response so we can check we are testing the right thing -- cgit 1.5.1 From 5a047816434e2ce2df8b80eb63a49c17dc3085fb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 15:31:09 +0000 Subject: rename get_prev_events_for_room to get_prev_events_and_hashes_for_room ... to make way for a new method which just returns the event ids --- synapse/handlers/message.py | 6 ++++-- synapse/handlers/room_member.py | 4 +++- synapse/storage/data_stores/main/event_federation.py | 5 +++-- tests/storage/test_event_federation.py | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4ad752205f..2695975a16 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -740,7 +740,7 @@ class EventCreationHandler(object): % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = yield self.store.get_prev_events_for_room( + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( builder.room_id ) @@ -1042,7 +1042,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( + room_id + ) latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 44c5e3239c..91bb34cd55 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -370,7 +370,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( + room_id + ) latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 1f517e8fad..266fc9715f 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -149,9 +149,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) @defer.inlineCallbacks - def get_prev_events_for_room(self, room_id): + def get_prev_events_and_hashes_for_room(self, room_id): """ - Gets a subset of the current forward extremities in the given room. + Gets a subset of the current forward extremities in the given room, + along with their depths and hashes. Limits the result to 10 extremities, so that we can avoid creating events which refer to hundreds of prev_events. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index eadfb90a22..3a68bf3274 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -26,7 +26,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() @defer.inlineCallbacks - def test_get_prev_events_for_room(self): + def test_get_prev_events_and_hashes_for_room(self): room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities @@ -64,7 +64,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others - r = yield self.store.get_prev_events_for_room(room_id) + r = yield self.store.get_prev_events_and_hashes_for_room(room_id) self.assertEqual(10, len(r)) for i in range(0, 5): el = r[i] -- cgit 1.5.1 From 3bef62488e5cff4dfb33454f2f2e18cc928f319b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:19:55 +0000 Subject: Remove unused hashes and depths from create_event params --- synapse/handlers/message.py | 21 +++++---------------- synapse/handlers/room_member.py | 8 +++++++- tests/unittest.py | 6 +----- 3 files changed, 13 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5415b0c9ee..8ea3aca2f4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -422,7 +422,7 @@ class EventCreationHandler(object): event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, + prev_event_ids: Optional[Collection[str]] = None, require_consent=True, ): """ @@ -439,10 +439,9 @@ class EventCreationHandler(object): token_id (str) txn_id (str) - prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + prev_event_ids: the forward extremities to use as the prev_events for the - new event. For each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + new event. If None, they will be requested from the database. @@ -497,12 +496,6 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - prev_event_ids = ( - None - if prev_events_and_hashes is None - else [event_id for event_id, _, _ in prev_events_and_hashes] - ) - event, context = yield self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -1038,11 +1031,7 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( - room_id - ) - - latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) + latest_event_ids = yield self.store.get_prev_events_for_room(room_id) members = yield self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids @@ -1061,7 +1050,7 @@ class EventCreationHandler(object): "room_id": room_id, "sender": user_id, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=latest_event_ids, ) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 91bb34cd55..d550ba8ab4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -164,6 +164,12 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" + prev_event_ids = ( + None + if prev_events_and_hashes is None + else [event_id for event_id, _, _ in prev_events_and_hashes] + ) + event, context = yield self.event_creation_handler.create_event( requester, { @@ -177,7 +183,7 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, require_consent=require_consent, ) diff --git a/tests/unittest.py b/tests/unittest.py index b30b7d1718..07b50c0ccd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -522,10 +522,6 @@ class HomeserverTestCase(TestCase): secrets = self.hs.get_secrets() requester = Requester(user, None, False, None, None) - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - event, context = self.get_success( event_creator.create_event( requester, @@ -535,7 +531,7 @@ class HomeserverTestCase(TestCase): "sender": user.to_string(), "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, ) ) -- cgit 1.5.1 From dc41fbf0dda981df117d8cf1938e023a38836cda Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:30:51 +0000 Subject: Remove unused get_prev_events_and_hashes_for_room --- .../storage/data_stores/main/event_federation.py | 30 ---------------------- tests/storage/test_event_federation.py | 19 +++++--------- 2 files changed, 6 insertions(+), 43 deletions(-) (limited to 'tests') diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 88e6489576..32e76621a7 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,7 +14,6 @@ # limitations under the License. import itertools import logging -import random from six.moves import range from six.moves.queue import Empty, PriorityQueue @@ -148,35 +147,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas retcol="event_id", ) - @defer.inlineCallbacks - def get_prev_events_and_hashes_for_room(self, room_id): - """ - Gets a subset of the current forward extremities in the given room, - along with their depths and hashes. - - Limits the result to 10 extremities, so that we can avoid creating - events which refer to hundreds of prev_events. - - Args: - room_id (str): room_id - - Returns: - Deferred[list[(str, dict[str, str], int)]] - for each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. - """ - res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) - if len(res) > 10: - # Sort by reverse depth, so we point to the most recent. - res.sort(key=lambda a: -a[2]) - - # we use half of the limit for the actual most recent events, and - # the other half to randomly point to some of the older events, to - # make sure that we don't completely ignore the older events. - res = res[0:5] + random.sample(res[5:], 5) - - return res - def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 3a68bf3274..a331517f4d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -26,7 +26,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() @defer.inlineCallbacks - def test_get_prev_events_and_hashes_for_room(self): + def test_get_prev_events_for_room(self): room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities @@ -60,21 +60,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): (event_id, bytearray(b"ffff")), ) - for i in range(0, 11): + for i in range(0, 20): yield self.store.db.runInteraction("insert", insert_event, i) - # this should get the last five and five others - r = yield self.store.get_prev_events_and_hashes_for_room(room_id) + # this should get the last ten + r = yield self.store.get_prev_events_for_room(room_id) self.assertEqual(10, len(r)) - for i in range(0, 5): - el = r[i] - depth = el[2] - self.assertEqual(10 - i, depth) - - for i in range(5, 5): - el = r[i] - depth = el[2] - self.assertLessEqual(5, depth) + for i in range(0, 10): + self.assertEqual("$event_%i:local" % (19 - i), r[i]) @defer.inlineCallbacks def test_get_rooms_with_many_extremities(self): -- cgit 1.5.1 From d20c3465441cd64ba3a1e84ee399bbadc0997bdf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Jan 2020 14:09:07 +0000 Subject: port BackgroundUpdateTestCase to HomeserverTestCase (#6653) --- changelog.d/6653.misc | 1 + tests/storage/test_background_update.py | 72 +++++++++++++++++---------------- 2 files changed, 38 insertions(+), 35 deletions(-) create mode 100644 changelog.d/6653.misc (limited to 'tests') diff --git a/changelog.d/6653.misc b/changelog.d/6653.misc new file mode 100644 index 0000000000..fbe7c0e7db --- /dev/null +++ b/changelog.d/6653.misc @@ -0,0 +1 @@ +Port core background update routines to async/await. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index aec76f4ab1..ae14fb407d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -2,44 +2,37 @@ from mock import Mock from twisted.internet import defer +from synapse.storage.background_updates import BackgroundUpdater + from tests import unittest -from tests.utils import setup_test_homeserver -class BackgroundUpdateTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - self.clock = hs.get_clock() +class BackgroundUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + # the base test class should have run the real bg updates for us + self.assertTrue(self.updates.has_completed_background_updates()) self.update_handler = Mock() - - yield self.store.db.updates.register_background_update_handler( + self.updates.register_background_update_handler( "test_update", self.update_handler ) - # run the real background updates, to get them out the way - # (perhaps we should run them as part of the test HS setup, since we - # run all of the other schema setup stuff there?) - while True: - res = yield self.store.db.updates.do_next_background_update(1000) - if res is None: - break - - @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000 + # the time we claim each update takes duration_ms = 42 + # the target runtime for each bg update + target_background_update_duration_ms = 50000 + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): - self.clock.advance_time_msec(count * duration_ms) + yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.store.db.runInteraction( + yield self.hs.get_datastore().db.runInteraction( "update_progress", - self.store.db.updates._background_update_progress_txn, + self.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,37 +40,46 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update - yield self.store.db.updates.start_background_update( - "test_update", {"my_key": 1} + self.get_success( + self.updates.start_background_update("test_update", {"my_key": 1}) ) - self.update_handler.reset_mock() - result = yield self.store.db.updates.do_next_background_update( - duration_ms * desired_count + res = self.get_success( + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, ) - self.assertIsNotNone(result) + self.assertIsNotNone(res) + + # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update + # we should now get run with a much bigger number of items to update @defer.inlineCallbacks def update(progress, count): - yield self.store.db.updates._end_background_update("test_update") + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, target_background_update_duration_ms / duration_ms, places=0, + ) + yield self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.db.updates.do_next_background_update( - duration_ms * desired_count + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) ) self.assertIsNotNone(result) - self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) + self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.db.updates.do_next_background_update( - duration_ms * desired_count + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) -- cgit 1.5.1