From 39230d217104f3cd7aba9065dc478f935ce1e614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Mar 2020 14:45:33 +0000 Subject: Clean up some LoggingContext stuff (#7120) * Pull Sentinel out of LoggingContext ... and drop a few unnecessary references to it * Factor out LoggingContext.current_context move `current_context` and `set_context` out to top-level functions. Mostly this means that I can more easily trace what's actually referring to LoggingContext, but I think it's generally neater. * move copy-to-parent into `stop` this really just makes `start` and `stop` more symetric. It also means that it behaves correctly if you manually `set_log_context` rather than using the context manager. * Replace `LoggingContext.alive` with `finished` Turn `alive` into `finished` and make it a bit better defined. --- tests/unittest.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'tests/unittest.py') diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..439174dbfc 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -38,7 +38,11 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.logging.context import LoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + current_context, + set_current_context, +) from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter @@ -97,10 +101,10 @@ class TestCase(unittest.TestCase): def setUp(orig): # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): self.fail( "Test starting with non-sentinel logging context %s" - % (LoggingContext.current_context(),) + % (current_context(),) ) old_level = logging.getLogger().level @@ -122,7 +126,7 @@ class TestCase(unittest.TestCase): # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) gc.collect() - LoggingContext.set_current_context(LoggingContext.sentinel) + set_current_context(SENTINEL_CONTEXT) return ret -- cgit 1.5.1 From 28d9d6e8a9d6a6d5162de41cada1b6d6d4b0f941 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 24 Mar 2020 18:33:49 +0000 Subject: Remove spurious "name" parameter to `default_config` this is never set to anything other than "test", and is a source of unnecessary boilerplate. --- tests/app/test_frontend_proxy.py | 4 ++-- tests/app/test_openid_listener.py | 4 ++-- tests/federation/test_complexity.py | 4 ++-- tests/handlers/test_register.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 4 ++-- tests/rest/key/v2/test_remote_key_resource.py | 4 ++-- tests/server_notices/test_resource_limits_server_notices.py | 2 +- tests/test_terms_auth.py | 4 ++-- tests/unittest.py | 7 ++----- 9 files changed, 16 insertions(+), 19 deletions(-) (limited to 'tests/unittest.py') diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index d3feafa1b7..be20a89682 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -27,8 +27,8 @@ class FrontendProxyTests(HomeserverTestCase): return hs - def default_config(self, name="test"): - c = super().default_config(name) + def default_config(self): + c = super().default_config() c["worker_app"] = "synapse.app.frontend_proxy" return c diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 89fcc3889a..7364f9f1ec 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -29,8 +29,8 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): ) return hs - def default_config(self, name="test"): - conf = super().default_config(name) + def default_config(self): + conf = super().default_config() # we're using FederationReaderServer, which uses a SlavedStore, so we # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 24fa8dbb45..94980733c4 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): login.register_servlets, ] - def default_config(self, name="test"): - config = super().default_config(name=name) + def default_config(self): + config = super().default_config() config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e2915eb7b1..e7b638dbfe 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() # some of the tests rely on us having a user consent version hs_config["user_consent"] = { diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d0c997e385..b6ed06e02d 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -36,8 +36,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] url = b"/_matrix/client/r0/register" - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config["allow_guest_access"] = True return config diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 6776a56cad..99eb477149 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): endpoint, to check that the two implementations are compatible. """ - def default_config(self, *args, **kwargs): - config = super().default_config(*args, **kwargs) + def default_config(self): + config = super().default_config() # replace the signing key with our own self.hs_signing_key = signedjson.key.generate_signing_key("kssk") diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index eb540e34f6..0d27b92a86 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -28,7 +28,7 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() hs_config["server_notices"] = { "system_mxid_localpart": "server", "system_mxid_display_name": "test display name", diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5ec5d2b358..81d796f3f3 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -28,8 +28,8 @@ from tests import unittest class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] - def default_config(self, name="test"): - config = super().default_config(name) + def default_config(self): + config = super().default_config() config.update( { "public_baseurl": "https://example.org/", diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..23b59bea22 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -311,14 +311,11 @@ class HomeserverTestCase(TestCase): return resource - def default_config(self, name="test"): + def default_config(self): """ Get a default HomeServer config dict. - - Args: - name (str): The homeserver name/domain. """ - config = default_config(name) + config = default_config("test") # apply any additional config which was specified via the override_config # decorator. -- cgit 1.5.1 From 665630fcaab8f09e83ff77f35d5244a718e20701 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 27 Mar 2020 11:39:43 +0000 Subject: Add tests for outbound device pokes --- changelog.d/7157.misc | 1 + tests/federation/test_federation_sender.py | 303 ++++++++++++++++++++++++++++- tests/unittest.py | 1 + 3 files changed, 302 insertions(+), 3 deletions(-) create mode 100644 changelog.d/7157.misc (limited to 'tests/unittest.py') diff --git a/changelog.d/7157.misc b/changelog.d/7157.misc new file mode 100644 index 0000000000..0eb1128c7a --- /dev/null +++ b/changelog.d/7157.misc @@ -0,0 +1 @@ +Add tests for outbound device pokes. diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d456267b87..7763b12159 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,19 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from mock import Mock +from signedjson import key, sign +from signedjson.types import BaseKey, SigningKey + from twisted.internet import defer -from synapse.types import ReadReceipt +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.types import JsonDict, ReadReceipt from tests.unittest import HomeserverTestCase, override_config -class FederationSenderTestCases(HomeserverTestCase): +class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - return super(FederationSenderTestCases, self).setup_test_homeserver( + return self.setup_test_homeserver( state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -147,3 +153,294 @@ class FederationSenderTestCases(HomeserverTestCase): } ], ) + + +class FederationSenderDevicesTestCases(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + state_handler=Mock(spec=["get_current_hosts_in_room"]), + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def default_config(self): + c = super().default_config() + c["send_federation"] = True + return c + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + mock_state_handler = hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + # stub out get_users_who_share_room_with_user so that it claims that + # `@user2:host2` is in the room + def get_users_who_share_room_with_user(user_id): + return defer.succeed({"@user2:host2"}) + + hs.get_datastore().get_users_who_share_room_with_user = ( + get_users_who_share_room_with_user + ) + + # whenever send_transaction is called, record the edu data + self.edus = [] + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + def record_transaction(self, txn, json_cb): + data = json_cb() + self.edus.extend(data["edus"]) + return defer.succeed({}) + + def test_send_device_updates(self): + """Basic case: each device update should result in an EDU""" + # create a device + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + + # expect one edu + self.assertEqual(len(self.edus), 1) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # a second call should produce no new device EDUs + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + self.assertEqual(self.edus, []) + + # a second device + self.login("user", "pass", device_id="D2") + + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + def test_upload_signatures(self): + """Uploading signatures on some devices should produce updates for that user""" + + e2e_handler = self.hs.get_e2e_keys_handler() + + # register two devices + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + self.login(u1, "pass", device_id="D2") + + # expect two edus + self.assertEqual(len(self.edus), 2) + stream_id = None + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload signing keys for each device + device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") + device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + + # expect two more edus + self.assertEqual(len(self.edus), 2) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload master key and self-signing key + master_signing_key = generate_self_id_key() + master_key = { + "user_id": u1, + "usage": ["master"], + "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)}, + } + + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_signing_key = generate_self_id_key() + selfsigning_key = { + "user_id": u1, + "usage": ["self_signing"], + "keys": { + key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key) + }, + } + sign.sign_json(selfsigning_key, u1, master_signing_key) + + cross_signing_keys = { + "master_key": master_key, + "self_signing_key": selfsigning_key, + } + + self.get_success( + e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) + ) + + # expect signing key update edu + self.assertEqual(len(self.edus), 1) + self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + + # sign the devices + d1_json = build_device_dict(u1, "D1", device1_signing_key) + sign.sign_json(d1_json, u1, selfsigning_signing_key) + d2_json = build_device_dict(u1, "D2", device2_signing_key) + sign.sign_json(d2_json, u1, selfsigning_signing_key) + + ret = self.get_success( + e2e_handler.upload_signatures_for_device_keys( + u1, {u1: {"D1": d1_json, "D2": d2_json}}, + ) + ) + self.assertEqual(ret["failures"], {}) + + # expect two edus, in one or two transactions. We don't know what order the + # devices will be updated. + self.assertEqual(len(self.edus), 2) + stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + if stream_id is not None: + self.assertEqual(c["prev_id"], [stream_id]) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2"}, devices) + + def test_delete_devices(self): + """If devices are deleted, that should result in EDUs too""" + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # expect three edus + self.assertEqual(len(self.edus), 3) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + # expect three edus, in an unknown order + self.assertEqual(len(self.edus), 3) + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertGreaterEqual( + c.items(), + {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), + ) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_unreachable_server(self): + """If the destination server is unreachable, all the updates should get sent on + recovery + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # for each device, there should be a single update + self.assertEqual(len(self.edus), 3) + stream_id = None + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def check_device_update_edu( + self, + edu: JsonDict, + user_id: str, + device_id: str, + prev_stream_id: Optional[int], + ) -> int: + """Check that the given EDU is an update for the given device + Returns the stream_id. + """ + self.assertEqual(edu["edu_type"], "m.device_list_update") + content = edu["content"] + + expected = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_stream_id] if prev_stream_id is not None else [], + } + + self.assertLessEqual(expected.items(), content.items()) + return content["stream_id"] + + def check_signing_key_update_txn(self, txn: JsonDict,) -> None: + """Check that the txn has an EDU with a signing key update. + """ + edus = txn["edus"] + self.assertEqual(len(edus), 1) + + def generate_and_upload_device_signing_key( + self, user_id: str, device_id: str + ) -> SigningKey: + """Generate a signing keypair for the given device, and upload it""" + sk = key.generate_signing_key(device_id) + + device_dict = build_device_dict(user_id, device_id, sk) + + self.get_success( + self.hs.get_e2e_keys_handler().upload_keys_for_user( + user_id, device_id, {"device_keys": device_dict}, + ) + ) + return sk + + +def generate_self_id_key() -> SigningKey: + """generate a signing key whose version is its public key + + ... as used by the cross-signing-keys. + """ + k = key.generate_signing_key("x") + k.version = encode_pubkey(k) + return k + + +def key_id(k: BaseKey) -> str: + return "%s:%s" % (k.alg, k.version) + + +def encode_pubkey(sk: SigningKey) -> str: + """Encode the public key corresponding to the given signing key as base64""" + return key.encode_verify_key_base64(key.get_verify_key(sk)) + + +def build_device_dict(user_id: str, device_id: str, sk: SigningKey): + """Build a dict representing the given device""" + return { + "user_id": user_id, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:" + device_id: "curve25519+key", + key_id(sk): encode_pubkey(sk), + }, + } diff --git a/tests/unittest.py b/tests/unittest.py index 23b59bea22..3d57b77a5d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -490,6 +490,7 @@ class HomeserverTestCase(TestCase): "password": password, "admin": admin, "mac": want_mac, + "inhibit_login": True, } ) request, channel = self.make_request( -- cgit 1.5.1 From 51f4d52cb4663a056372d779b78488aeae45f554 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Mar 2020 17:27:56 +0100 Subject: Set a logging context while running the bg updates This mostly just reduces the amount of "running from sentinel context" spam during unittest setup. --- tests/unittest.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'tests/unittest.py') diff --git a/tests/unittest.py b/tests/unittest.py index d0406ca2fd..27af5228fe 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -40,6 +40,7 @@ from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import ( SENTINEL_CONTEXT, + LoggingContext, current_context, set_current_context, ) @@ -419,15 +420,17 @@ class HomeserverTestCase(TestCase): config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj + async def run_bg_updates(): + with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() # Run the database background updates, when running against "master". if hs.__class__.__name__ == "TestHomeServer": - while not self.get_success( - stor.db.updates.has_completed_background_updates() - ): - self.get_success(stor.db.updates.do_next_background_update(1)) + self.get_success(run_bg_updates()) return hs -- cgit 1.5.1 From c2e1a2110fbe9ead26b4ecbb1afd504ed035a04d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Apr 2020 12:30:36 +0100 Subject: Fix limit logic for EventsStream (#7358) * Factor out functions for injecting events into database I want to add some more flexibility to the tools for injecting events into the database, and I don't want to clutter up HomeserverTestCase with them, so let's factor them out to a new file. * Rework TestReplicationDataHandler This wasn't very easy to work with: the mock wrapping was largely superfluous, and it's useful to be able to inspect the received rows, and clear out the received list. * Fix AssertionErrors being thrown by EventsStream Part of the problem was that there was an off-by-one error in the assertion, but also the limit logic was too simple. Fix it all up and add some tests. --- changelog.d/7358.bugfix | 1 + synapse/replication/tcp/handler.py | 4 +- synapse/replication/tcp/streams/events.py | 22 +- synapse/server.pyi | 5 + synapse/storage/data_stores/main/events_worker.py | 64 +++- tests/replication/tcp/streams/_base.py | 41 ++- tests/replication/tcp/streams/test_events.py | 417 ++++++++++++++++++++++ tests/replication/tcp/streams/test_receipts.py | 10 +- tests/replication/tcp/streams/test_typing.py | 11 +- tests/rest/client/v1/utils.py | 2 +- tests/test_utils/__init__.py | 20 ++ tests/test_utils/event_injection.py | 96 +++++ tests/unittest.py | 30 +- tox.ini | 2 + 14 files changed, 658 insertions(+), 67 deletions(-) create mode 100644 changelog.d/7358.bugfix create mode 100644 tests/replication/tcp/streams/test_events.py create mode 100644 tests/test_utils/event_injection.py (limited to 'tests/unittest.py') diff --git a/changelog.d/7358.bugfix b/changelog.d/7358.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7358.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0db5a3a24d..3a8c7c7e2d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -87,7 +87,9 @@ class ReplicationCommandHandler: stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] - self._position_linearizer = Linearizer("replication_position") + self._position_linearizer = Linearizer( + "replication_position", clock=self._clock + ) # Map of stream to batched updates. See RdataCommand for info on how # batching works. diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index aa50492569..52df81b1bd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -170,22 +170,16 @@ class EventsStream(Stream): limited = False upper_limit = current_token - # next up is the state delta table - - state_rows = await self._store.get_all_updated_current_state_deltas( + # next up is the state delta table. + ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( from_token, upper_limit, target_row_count - ) # type: List[Tuple] - - # again, if we've hit the limit there, we'll need to limit the other sources - assert len(state_rows) < target_row_count - if len(state_rows) == target_row_count: - assert state_rows[-1][0] <= upper_limit - upper_limit = state_rows[-1][0] - limited = True + ) - # FIXME: is it a given that there is only one row per stream_id in the - # state_deltas table (so that we can be sure that we have got all of the - # rows for upper_limit)? + limited = limited or state_rows_limited # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. diff --git a/synapse/server.pyi b/synapse/server.pyi index f1a5717028..fc5886f762 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory class HomeServer(object): @property @@ -121,3 +122,7 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> synapse.storage.Storage: + pass diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ce8be72bfe..73df6b33ba 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,7 +19,7 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Tuple from canonicaljson import json from constantly import NamedConstant, Names @@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore): "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id @@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) + txn.execute(sql, (from_token, to_token, target_row_count)) return txn.fetchall() - return self.db.runInteraction( + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. + + rows = await self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 82f15c64e0..83e16cfe3d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Optional -from mock import Mock +import logging +from typing import Any, Dict, List, Optional, Tuple import attr @@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel from synapse.app.generic_worker import GenericWorkerServer from synapse.http.site import SynapseRequest +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db = hs.get_datastore().db - self.test_handler = Mock( - wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) - ) + self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) @@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs.get_datastore()) + def reconnect(self): if self._client_transport: self.client.close() @@ -174,22 +175,28 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationDataHandler(ReplicationDataHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, hs): - super().__init__(hs) - self.streams = set() - self._received_rdata_rows = [] + def __init__(self, store: BaseSlavedStore): + super().__init__(store) + + # streams to subscribe to: map from stream id to position + self.stream_positions = {} # type: Dict[str, int] + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] def get_streams_to_replicate(self): - positions = {s: 0 for s in self.streams} - for stream, token, _ in self._received_rdata_rows: - if stream in self.streams: - positions[stream] = max(token, positions.get(stream, 0)) - return positions + return self.stream_positions async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: - self._received_rdata_rows.append((stream_name, token, r)) + self.received_rdata_rows.append((stream_name, token, r)) + + if ( + stream_name in self.stream_positions + and token > self.stream_positions[stream_name] + ): + self.stream_positions[stream_name] = token @attr.s() @@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel): super().__init__() self.reactor = reactor - self._pull_to_push_producer = None + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] def registerProducer(self, producer, streaming): # Convert pull producers to push producer. diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..1fa28084f9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication.tcp.streams._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + self.test_handler.stream_positions["events"] = 0 + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). + events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. + + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # now we should have received all the expected rows in the right order. + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. + # - two rows for state2 + + received_rows = self.test_handler.received_rdata_rows + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. + + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + + received_rows = self.test_handler.received_rdata_rows + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index a0206f7363..c122b8589c 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,6 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# type: ignore + +from mock import Mock + from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -20,11 +25,14 @@ USER_ID = "@feeling:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): self.reconnect() # make the client subscribe to the receipts stream - self.test_handler.streams.add("receipts") + self.test_handler.stream_positions.update({"receipts": 0}) # tell the master to send a new receipt self.get_success( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index f0ad6402ae..4d354a9db8 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.handlers.typing import RoomMember from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream @@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase): streams.register_servlets, ] + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_typing(self): typing = self.hs.get_typing_handler() @@ -33,8 +38,8 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.streams.add("typing") + # make the client subscribe to the typing stream + self.test_handler.stream_positions.update({"typing": 0}) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) @@ -75,6 +80,6 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row = rdata_rows[0] self.assertEqual(room_id, row.room_id) self.assertEqual([], row.user_ids) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 371637618d..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..8f6872761a --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. +""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event diff --git a/tests/unittest.py b/tests/unittest.py index 27af5228fe..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server @@ -55,6 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase): """ Inject a membership event into a room. + Deprecated: use event_injection.inject_room_member directly + Args: room: Room ID to inject the event into. user: MXID of the user to inject the membership for. membership: The membership type. """ - event_builder_factory = self.hs.get_event_builder_factory() - event_creation_handler = self.hs.get_event_creation_handler() - - room_version = self.get_success( - self.hs.get_datastore().get_room_version_id(room) - ) - - builder = event_builder_factory.for_room_version( - KNOWN_ROOM_VERSIONS[room_version], - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - event_creation_handler.create_new_client_event(builder) - ) - - self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) - ) + event_injection.inject_member_event(self.hs, room, user, membership) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 31011d7436..2630857436 100644 --- a/tox.ini +++ b/tox.ini @@ -204,6 +204,8 @@ commands = mypy \ synapse/storage/database.py \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ + tests/replication/tcp/streams \ + tests/test_utils \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.5.1