From 58f61f10f780a5f9e6be99f4072c24442594d597 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Fri, 4 Sep 2020 12:22:23 +0100 Subject: Catch-up after Federation Outage (split, 1) (#8230) Signed-off-by: Olivier Wilkinson (reivilibre) --- tests/federation/test_federation_catch_up.py | 82 ++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/federation/test_federation_catch_up.py (limited to 'tests') diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py new file mode 100644 index 0000000000..73c51c9d6c --- /dev/null +++ b/tests/federation/test_federation_catch_up.py @@ -0,0 +1,82 @@ +from mock import Mock + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.test_utils import event_injection, make_awaitable +from tests.unittest import FederatingHomeserverTestCase, override_config + + +class FederationCatchUpTestCases(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + state_handler = hs.get_state_handler() + + # This mock is crucial for destination_rooms to be populated. + state_handler.get_current_hosts_in_room = Mock( + return_value=make_awaitable(["test", "host2"]) + ) + + def get_destination_room(self, room: str, destination: str = "host2") -> dict: + """ + Gets the destination_rooms entry for a (destination, room_id) pair. + + Args: + room: room ID + destination: what destination, default is "host2" + + Returns: + Dictionary of { event_id: str, stream_ordering: int } + """ + event_id, stream_ordering = self.get_success( + self.hs.get_datastore().db_pool.execute( + "test:get_destination_rooms", + None, + """ + SELECT event_id, stream_ordering + FROM destination_rooms dr + JOIN events USING (stream_ordering) + WHERE dr.destination = ? AND dr.room_id = ? + """, + destination, + room, + ) + )[0] + return {"event_id": event_id, "stream_ordering": stream_ordering} + + @override_config({"send_federation": True}) + def test_catch_up_destination_rooms_tracking(self): + """ + Tests that we populate the `destination_rooms` table as needed. 
+ """ + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room, "@user:host2", "join") + ) + + event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"] + + row_1 = self.get_destination_room(room) + + event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] + + row_2 = self.get_destination_room(room) + + # check: events correctly registered in order + self.assertEqual(row_1["event_id"], event_id_1) + self.assertEqual(row_2["event_id"], event_id_2) + self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) -- cgit 1.5.1 From 765437df54a0e74f37a2c65515f54755096eeecd Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 7 Sep 2020 10:11:38 +0100 Subject: Add tests for `last_successful_stream_ordering` (#8258) --- changelog.d/8258.misc | 1 + tests/federation/test_federation_catch_up.py | 76 ++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 changelog.d/8258.misc (limited to 'tests') diff --git a/changelog.d/8258.misc b/changelog.d/8258.misc new file mode 100644 index 0000000000..3c27803be4 --- /dev/null +++ b/changelog.d/8258.misc @@ -0,0 +1 @@ +Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage. diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 73c51c9d6c..6cdcc378f0 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -28,6 +28,24 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): return_value=make_awaitable(["test", "host2"]) ) + # whenever send_transaction is called, record the pdu data + self.pdus = [] + self.failed_pdus = [] + self.is_online = True + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + async def record_transaction(self, txn, json_cb): + if self.is_online: + data = json_cb() + self.pdus.extend(data["pdus"]) + return {} + else: + data = json_cb() + self.failed_pdus.extend(data["pdus"]) + raise IOError("Failed to connect because this is a test!") + def get_destination_room(self, room: str, destination: str = "host2") -> dict: """ Gets the destination_rooms entry for a (destination, room_id) pair. @@ -80,3 +98,61 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.assertEqual(row_1["event_id"], event_id_1) self.assertEqual(row_2["event_id"], event_id_2) self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) + + @override_config({"send_federation": True}) + def test_catch_up_last_successful_stream_ordering_tracking(self): + """ + Tests that we populate the `destination_rooms` table as needed. 
+ """ + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room = self.helper.create_room_as("u1", tok=u1_token) + + # take the remote offline + self.is_online = False + + self.get_success( + event_injection.inject_member_event(self.hs, room, "@user:host2", "join") + ) + + self.helper.send(room, "wombats!", tok=u1_token) + self.pump() + + lsso_1 = self.get_success( + self.hs.get_datastore().get_destination_last_successful_stream_ordering( + "host2" + ) + ) + + self.assertIsNone( + lsso_1, + "There should be no last successful stream ordering for an always-offline destination", + ) + + # bring the remote online + self.is_online = True + + event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] + + lsso_2 = self.get_success( + self.hs.get_datastore().get_destination_last_successful_stream_ordering( + "host2" + ) + ) + row_2 = self.get_destination_room(room) + + self.assertEqual( + self.pdus[0]["content"]["body"], + "rabbits!", + "Test fault: didn't receive the right PDU", + ) + self.assertEqual( + row_2["event_id"], + event_id_2, + "Test fault: destination_rooms not updated correctly", + ) + self.assertEqual( + lsso_2, + row_2["stream_ordering"], + "Send succeeded but not marked as last_successful_stream_ordering", + ) -- cgit 1.5.1 From 68cdb3708e2cefabf81e6e027c8dbd92169207a2 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 8 Sep 2020 11:05:59 +0100 Subject: Rename 'populate_stats_process_rooms_2' background job back to 'populate_stats_process_rooms' again (#8243) Fixes https://github.com/matrix-org/synapse/issues/8238 Alongside the delta file, some changes were also necessary to the codebase to remove references to the now defunct `populate_stats_process_rooms_2` background job. Thankfully the latter doesn't seem to have made it into any documentation yet :) --- changelog.d/8243.misc | 1 + .../58/16populate_stats_process_rooms_fix.sql | 22 +++++++++++++ synapse/storage/databases/main/stats.py | 36 ++++------------------ tests/handlers/test_stats.py | 15 ++++----- 4 files changed, 35 insertions(+), 39 deletions(-) create mode 100644 changelog.d/8243.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql (limited to 'tests') diff --git a/changelog.d/8243.misc b/changelog.d/8243.misc new file mode 100644 index 0000000000..f7375d32d3 --- /dev/null +++ b/changelog.d/8243.misc @@ -0,0 +1 @@ +Remove the 'populate_stats_process_rooms_2' background job and restore functionality to 'populate_stats_process_rooms'. \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql new file mode 100644 index 0000000000..55f5d0f732 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky +-- populate_stats_process_rooms_2 background job and restores the functionality under the +-- original name. +-- See https://github.com/matrix-org/synapse/issues/8238 for details + +DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms'; +UPDATE background_updates SET update_name = 'populate_stats_process_rooms' + WHERE update_name = 'populate_stats_process_rooms_2'; diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 55a250ef06..30840dbbaa 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -73,9 +73,6 @@ class StatsStore(StateDeltasStore): self.db_pool.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.db_pool.updates.register_background_update_handler( - "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2 - ) self.db_pool.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) @@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) async def _populate_stats_process_rooms(self, progress, batch_size): - """ - This was a background update which regenerated statistics for rooms. - - It has been replaced by StatsStore._populate_stats_process_rooms_2. This background - job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure - someone upgrading from Date: Tue, 8 Sep 2020 07:26:55 -0400 Subject: Allow for make_awaitable's return value to be re-used. (#8261) --- changelog.d/8261.misc | 1 + tests/federation/test_complexity.py | 30 ++++++++-------------- tests/federation/test_federation_sender.py | 2 +- tests/handlers/test_auth.py | 20 +++++++-------- tests/handlers/test_register.py | 10 ++++---- tests/handlers/test_typing.py | 2 +- tests/replication/test_federation_sender_shard.py | 10 ++++---- tests/rest/admin/test_user.py | 6 ++--- .../test_resource_limits_server_notices.py | 14 ++++------ tests/storage/test_client_ips.py | 2 +- tests/storage/test_monthly_active_users.py | 16 +++--------- tests/test_utils/__init__.py | 13 +++++++--- 12 files changed, 56 insertions(+), 70 deletions(-) create mode 100644 changelog.d/8261.misc (limited to 'tests') diff --git a/changelog.d/8261.misc b/changelog.d/8261.misc new file mode 100644 index 0000000000..bc91e9375c --- /dev/null +++ b/changelog.d/8261.misc @@ -0,0 +1 @@ +Simplify tests that mock asynchronous functions. 
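A brief aside on the `make_awaitable` change in this patch: the helper now returns a `Future`, which can be awaited repeatedly, so tests can hand a single `make_awaitable(...)` to a Mock's `return_value` instead of wrapping it in a `side_effect` lambda. The sketch below is editorial and not part of the patch; it uses plain `asyncio` and `unittest.mock` rather than Synapse's Twisted-based test harness, and the `fake_store` mock and the value 42 are invented for illustration.

    import asyncio
    from asyncio import Future
    from unittest.mock import Mock

    def make_awaitable(result):
        # A Future can be awaited any number of times once its result is set,
        # unlike a coroutine object, which raises on its second await.
        future = Future()
        future.set_result(result)
        return future

    async def main():
        fake_store = Mock(spec=["get_monthly_active_count"])
        fake_store.get_monthly_active_count.return_value = make_awaitable(42)

        # Both awaits resolve to the same result; no side_effect lambda needed.
        assert await fake_store.get_monthly_active_count() == 42
        assert await fake_store.get_monthly_active_count() == 42

    asyncio.run(main())
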
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 3d880c499d..1471cc1a28 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(None) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 5f512ff8bf..917762e6b6 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): mock_state_handler = 
Mock(spec=["get_current_hosts_in_room"]) # Ensure a new Awaitable is created for each call. - mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable( ["test", "host2"] ) return self.setup_test_homeserver( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c7efd3822d..97877c2e42 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.large_number_of_users) + return_value=make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): @@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.large_number_of_users) + return_value=make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase): # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase): ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase): ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.small_number_of_users) + return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception yield defer.ensureDeferred( @@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.small_number_of_users) + 
return_value=make_awaitable(self.small_number_of_users) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index eddf5e2498..cb7c0ed51a 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.lots_of_users) + return_value=make_awaitable(self.lots_of_users) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.lots_of_users) + return_value=make_awaitable(self.lots_of_users) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 7bf15c4ba9..ae6bc24f4c 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( + self.datastore.get_device_updates_by_remote.return_value = make_awaitable( (0, []) ) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8b4982ecb1..1d7edee5ba 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. 
""" mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 160c630235..b8b7758d24 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 973338ea71..6382b19dc3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn._store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(1000) + return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( return_value=defer.succeed(Mock()) @@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - 
self._rlsn._store.get_tags_for_room = Mock( - side_effect=lambda user_id, room_id: make_awaitable({}) - ) + self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): @@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(None) + return_value=make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(1000) - ) + self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(1000) + return_value=make_awaitable(1000) ) # Call the function multiple times to ensure we only send the notice once diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 370c247e16..755c70db31 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): user_id = "@user:server" self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(lots_of_users) + return_value=make_awaitable(lots_of_users) ) self.get_success( self.store.insert_client_ip( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9870c74883..643072bbaf 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: 
make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 508aeba078..a298cc0fd3 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,6 +17,7 @@ """ Utilities for running the unit tests """ +from asyncio import Future from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -async def make_awaitable(result: Any): - """Create an awaitable that just returns a result.""" - return result +def make_awaitable(result: Any) -> Awaitable[Any]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. + """ + future = Future() # type: ignore + future.set_result(result) + return future -- cgit 1.5.1 From deedb917325ea9ce8085df45dd925b8d583fd661 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Sep 2020 14:26:54 +0100 Subject: Fix `MultiWriterIdGenerator.current_position`. (#8257) It did not correctly handle IDs finishing being persisted out of order, resulting in the `current_position` lagging until new IDs are persisted. --- changelog.d/8257.misc | 1 + synapse/storage/util/id_generators.py | 43 +++++++++++++++++++++++++----- tests/storage/test_id_generators.py | 50 +++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8257.misc (limited to 'tests') diff --git a/changelog.d/8257.misc b/changelog.d/8257.misc new file mode 100644 index 0000000000..47ac583eb4 --- /dev/null +++ b/changelog.d/8257.misc @@ -0,0 +1 @@ +Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b7eb4f8ac9..2a66b3ad4e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -224,6 +224,10 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # Set of local IDs that we've processed that are larger than the current + # position, due to there being smaller unpersisted IDs. + self._finished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances # and b) noting that if we have seen a run of persisted positions @@ -348,17 +352,44 @@ class MultiWriterIdGenerator: def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the - current poistion if possible. + current position if possible. """ with self._lock: self._unfinished_ids.discard(next_id) + self._finished_ids.add(next_id) + + new_cur = None + + if self._unfinished_ids: + # If there are unfinished IDs then the new position will be the + # largest finished ID less than the minimum unfinished ID. + + finished = set() + + min_unfinshed = min(self._unfinished_ids) + for s in self._finished_ids: + if s < min_unfinshed: + if new_cur is None or new_cur < s: + new_cur = s + else: + finished.add(s) + + # We clear these out since they're now all less than the new + # position. 
+ self._finished_ids = finished + else: + # There are no unfinished IDs so the new position is simply the + # largest finished one. + new_cur = max(self._finished_ids) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids.clear() - # Figure out if its safe to advance the position by checking there - # aren't any lower allocated IDs that are yet to finish. - if all(c > next_id for c in self._unfinished_ids): + if new_cur: curr = self._current_positions.get(self._instance_name, 0) - self._current_positions[self._instance_name] = max(curr, next_id) + self._current_positions[self._instance_name] = max(curr, new_cur) self._add_persisted_position(next_id) @@ -428,7 +459,7 @@ class MultiWriterIdGenerator: # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values()) + min_curr = min(self._current_positions.values(), default=0) self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index f0a8e32f1e..20636fc400 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -122,6 +122,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + def test_out_of_order_finish(self): + """Test that IDs persisted out of order are correctly handled + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + ctx3 = self.get_success(id_gen.get_next()) + ctx4 = self.get_success(id_gen.get_next()) + + s1 = ctx1.__enter__() + s2 = ctx2.__enter__() + s3 = ctx3.__enter__() + s4 = ctx4.__enter__() + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + self.assertEqual(s3, 10) + self.assertEqual(s4, 11) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx2.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx4.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx3.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 11}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) + def test_multi_instance(self): """Test that reads and writes from multiple processes are handled correctly. 
-- cgit 1.5.1 From 094896a69d44a69946df099da59adec0b52da0ac Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 8 Sep 2020 16:03:09 +0100 Subject: Add a config option for validating 'next_link' parameters against a domain whitelist (#8275) This is a config option ported over from DINUM's Sydent: https://github.com/matrix-org/sydent/pull/285 They've switched to validating 3PIDs via Synapse rather than Sydent, and would like to retain this functionality. This original purpose for this change is phishing prevention. This solution could also potentially be replaced by a similar one to https://github.com/matrix-org/synapse/pull/8004, but across all `*/submit_token` endpoint. This option may still be useful to enterprise even with that safeguard in place though, if they want to be absolutely sure that their employees don't follow links to other domains. --- changelog.d/8275.feature | 1 + docs/sample_config.yaml | 18 +++++ synapse/config/server.py | 33 ++++++++- synapse/rest/client/v2_alpha/account.py | 66 +++++++++++++++--- tests/rest/client/v2_alpha/test_account.py | 103 +++++++++++++++++++++++++++-- 5 files changed, 204 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8275.feature (limited to 'tests') diff --git a/changelog.d/8275.feature b/changelog.d/8275.feature new file mode 100644 index 0000000000..17549c3df3 --- /dev/null +++ b/changelog.d/8275.feature @@ -0,0 +1 @@ +Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3528d9e11f..994b0a62c4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -432,6 +432,24 @@ retention: # #request_token_inhibit_3pid_errors: true +# A list of domains that the domain portion of 'next_link' parameters +# must match. +# +# This parameter is optionally provided by clients while requesting +# validation of an email or phone number, and maps to a link that +# users will be automatically redirected to after validation +# succeeds. Clients can make use this parameter to aid the validation +# process. +# +# The whitelist is applied whether the homeserver or an +# identity server is handling validation. +# +# The default value is no whitelist functionality; all domains are +# allowed. Setting this value to an empty list will instead disallow +# all domains. +# +#next_link_domain_whitelist: ["matrix.org"] + ## TLS ## diff --git a/synapse/config/server.py b/synapse/config/server.py index e85c6a0840..532b910470 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml @@ -542,6 +542,19 @@ class ServerConfig(Config): users_new_default_push_rules ) # type: set + # Whitelist of domain names that given next_link parameters must have + next_link_domain_whitelist = config.get( + "next_link_domain_whitelist" + ) # type: Optional[List[str]] + + self.next_link_domain_whitelist = None # type: Optional[Set[str]] + if next_link_domain_whitelist is not None: + if not isinstance(next_link_domain_whitelist, list): + raise ConfigError("'next_link_domain_whitelist' must be a list") + + # Turn the list into a set to improve lookup speed. 
+ self.next_link_domain_whitelist = set(next_link_domain_whitelist) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1014,6 +1027,24 @@ class ServerConfig(Config): # act as if no error happened and return a fake session ID ('sid') to clients. # #request_token_inhibit_3pid_errors: true + + # A list of domains that the domain portion of 'next_link' parameters + # must match. + # + # This parameter is optionally provided by clients while requesting + # validation of an email or phone number, and maps to a link that + # users will be automatically redirected to after validation + # succeeds. Clients can make use this parameter to aid the validation + # process. + # + # The whitelist is applied whether the homeserver or an + # identity server is handling validation. + # + # The default value is no whitelist functionality; all domains are + # allowed. Setting this value to an empty list will instead disallow + # all domains. + # + #next_link_domain_whitelist: ["matrix.org"] """ % locals() ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3481477731..455051ac46 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -17,6 +17,11 @@ import logging import random from http import HTTPStatus +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -98,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + # The email will be sent to the stored address. 
# This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -446,6 +454,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: @@ -517,6 +528,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: @@ -603,15 +617,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None # Otherwise show the success template html = self.config.email_add_threepid_template_success_html_content @@ -875,6 +884,45 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} +def assert_valid_next_link(hs: "HomeServer", next_link: str): + """ + Raises a SynapseError if a given next_link value is invalid + + next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config + option is either empty or contains a domain that matches the one in the given next_link + + Args: + hs: The homeserver object + next_link: The next_link value given by the client + + Raises: + SynapseError: If the next_link is invalid + """ + valid = True + + # Parse the contents of the URL + next_link_parsed = urlparse(next_link) + + # Scheme must not point to the local drive + if next_link_parsed.scheme == "file": + valid = False + + # If the domain whitelist is set, the domain must be in it + if ( + valid + and hs.config.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + ): + valid = False + + if not valid: + raise SynapseError( + 400, + "'next_link' domain not included in whitelist, or not http(s)", + errcode=Codes.INVALID_PARAM, + ) + + class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 152a5182fa..0a51aeff92 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -14,11 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import json import os import re from email.parser import Parser +from typing import Optional import pkg_resources @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest +from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): @@ -668,16 +669,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def _request_token(self, email, client_secret): + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + "POST", b"account/3pid/email/requestToken", body, ) self.render(request) - self.assertEquals(200, channel.code, channel.result) + self.assertEquals(expect_code, 
channel.code, channel.result) - return channel.json_body["sid"] + return channel.json_body.get("sid") def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", -- cgit 1.5.1 From a5370072b53e7ea3ebbd9404ee4133508c2d55b2 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Wed, 9 Sep 2020 11:39:39 +0100 Subject: Don't remember `enabled` of deleted push rules and properly return 404 for missing push rules in `.../actions` and `.../enabled` (#7796) Signed-off-by: Olivier Wilkinson (reivilibre) Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/7796.bugfix | 1 + synapse/rest/client/v1/push_rule.py | 15 +- synapse/storage/databases/main/push_rule.py | 131 +++++- .../58/10_pushrules_enabled_delete_obsolete.sql | 28 ++ tests/rest/client/v1/test_push_rule_attrs.py | 448 +++++++++++++++++++++ 5 files changed, 610 insertions(+), 13 deletions(-) create mode 100644 changelog.d/7796.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql create mode 100644 tests/rest/client/v1/test_push_rule_attrs.py (limited to 'tests') diff --git a/changelog.d/7796.bugfix b/changelog.d/7796.bugfix new file mode 100644 index 0000000000..65e5eb42a2 --- /dev/null +++ b/changelog.d/7796.bugfix @@ -0,0 +1 @@ +Fix inconsistent handling of non-existent push rules, and stop tracking the `enabled` state of removed push rules. diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index e781a3bcf4..ddf8ed5e9c 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet): self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) async def set_rule_attr(self, user_id, spec, val): + if spec["attr"] not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] @@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet): # This should *actually* take a dict, but many clients pass # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) return await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val + user_id, namespaced_rule_id, val, is_default_rule ) elif spec["attr"] == "actions": actions = val.get("actions") diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 0de802a86b..9790a31998 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -13,11 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import abc import logging from typing import List, Tuple, Union +from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder @@ -540,6 +541,25 @@ class PushRuleStore(PushRulesWorkerStore): }, ) + # ensure we have a push_rules_enable row + # enabledness defaults to true + if isinstance(self.database_engine, PostgresEngine): + sql = """ + INSERT INTO push_rules_enable (id, user_name, rule_id, enabled) + VALUES (?, ?, ?, ?) + ON CONFLICT DO NOTHING + """ + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled) + VALUES (?, ?, ?, ?) + """ + else: + raise RuntimeError("Unknown database engine") + + new_enable_id = self._push_rules_enable_id_gen.get_next() + txn.execute(sql, (new_enable_id, user_id, rule_id, 1)) + async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ Delete a push rule. Args specify the row to be deleted and can be @@ -552,6 +572,12 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + # we don't use simple_delete_one_txn because that would fail if the + # user did not have a push_rule_enable row. + self.db_pool.simple_delete_txn( + txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id} + ) + self.db_pool.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -570,10 +596,29 @@ class PushRuleStore(PushRulesWorkerStore): event_stream_ordering, ) - async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: + async def set_push_rule_enabled( + self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool + ) -> None: + """ + Sets the `enabled` state of a push rule. + + Args: + user_id: the user ID of the user who wishes to enable/disable the rule + e.g. '@tina:example.org' + rule_id: the full rule ID of the rule to be enabled/disabled + e.g. 'global/override/.m.rule.roomnotif' + or 'global/override/myCustomRule' + enabled: True if the rule is to be enabled, False if it is to be + disabled + is_default_rule: True if and only if this is a server-default rule. + This skips the check for existence (as only user-created rules + are always stored in the database `push_rules` table). + + Raises: + NotFoundError if the rule does not exist. 
+ """ with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() - await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, @@ -582,12 +627,47 @@ class PushRuleStore(PushRulesWorkerStore): user_id, rule_id, enabled, + is_default_rule, ) def _set_push_rule_enabled_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, + is_default_rule, ): new_id = self._push_rules_enable_id_gen.get_next() + + if not is_default_rule: + # first check it exists; we need to lock for key share so that a + # transaction that deletes the push rule will conflict with this one. + # We also need a push_rule_enable row to exist for every push_rules + # row, otherwise it is possible to simultaneously delete a push rule + # (that has no _enable row) and enable it, resulting in a dangling + # _enable row. To solve this: we either need to use SERIALISABLE or + # ensure we always have a push_rule_enable row for every push_rule + # row. We chose the latter. + for_key_share = "FOR KEY SHARE" + if not isinstance(self.database_engine, PostgresEngine): + # For key share is not applicable/available on SQLite + for_key_share = "" + sql = ( + """ + SELECT 1 FROM push_rules + WHERE user_name = ? AND rule_id = ? + %s + """ + % for_key_share + ) + txn.execute(sql, (user_id, rule_id)) + if txn.fetchone() is None: + # needed to set NOT_FOUND code. + raise NotFoundError("Push rule does not exist.") + self.db_pool.simple_upsert_txn( txn, "push_rules_enable", @@ -606,8 +686,30 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_actions( - self, user_id, rule_id, actions, is_default_rule + self, + user_id: str, + rule_id: str, + actions: List[Union[dict, str]], + is_default_rule: bool, ) -> None: + """ + Sets the `actions` state of a push rule. + + Will throw NotFoundError if the rule does not exist; the Code for this + is NOT_FOUND. + + Args: + user_id: the user ID of the user who wishes to enable/disable the rule + e.g. '@tina:example.org' + rule_id: the full rule ID of the rule to be enabled/disabled + e.g. 'global/override/.m.rule.roomnotif' + or 'global/override/myCustomRule' + actions: A list of actions (each action being a dict or string), + e.g. ["notify", {"set_tweak": "highlight", "value": false}] + is_default_rule: True if and only if this is a server-default rule. + This skips the check for existence (as only user-created rules + are always stored in the database `push_rules` table). 
+ """ actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): @@ -629,12 +731,19 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self.db_pool.simple_update_one_txn( - txn, - "push_rules", - {"user_name": user_id, "rule_id": rule_id}, - {"actions": actions_json}, - ) + try: + self.db_pool.simple_update_one_txn( + txn, + "push_rules", + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, + ) + except StoreError as serr: + if serr.code == 404: + # this sets the NOT_FOUND error Code + raise NotFoundError("Push rule does not exist") + else: + raise self._insert_push_rules_update_txn( txn, diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql new file mode 100644 index 0000000000..847aebd85e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql @@ -0,0 +1,28 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules. + We ignore rules that are server-default rules because they are not defined + in the `push_rules` table. +**/ + +DELETE FROM push_rules_enable WHERE + rule_id NOT LIKE 'global/%/.m.rule.%' + AND NOT EXISTS ( + SELECT 1 FROM push_rules + WHERE push_rules.user_name = push_rules_enable.user_name + AND push_rules.rule_id = push_rules_enable.rule_id + ); diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py new file mode 100644 index 0000000000..081052f6a6 --- /dev/null +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, push_rule, room + +from tests.unittest import HomeserverTestCase + + +class PushRuleAttributesTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + push_rule.register_servlets, + ] + hijack_auth = False + + def test_enabled_on_creation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. 
+ Tests that a rule is enabled upon creation, even though a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_on_recreation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even if a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # disable the rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check rule disabled + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # DELETE the rule + request, channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_disable(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is disabled and enabled when we ask for it. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # disable the rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check rule disabled + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # re-enable the rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check rule enabled + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_404_when_get_non_existent(self): + """ + Tests that `enabled` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # DELETE the rule + request, channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_get_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when the server-default rule doesn't exist. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_rule(self): + """ + Tests that `enabled` gives 404 when we put to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/enabled", + {"enabled": True}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_get(self): + """ + Tests that `actions` gives you what you expect on a fresh rule. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] + ) + + def test_actions_put(self): + """ + Tests that PUT on actions updates the value you'd get from GET. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # change the rule actions + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["actions"], ["dont_notify"]) + + def test_actions_404_when_get_non_existent(self): + """ + Tests that `actions` gives 404 when the rule doesn't exist. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # DELETE the rule + request, channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_get_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_rule(self): + """ + Tests that `actions` gives 404 when putting to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) -- cgit 1.5.1 From e7fd336a53a4ca489cdafc389b494d5477019dc0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Sep 2020 16:17:50 +0100 Subject: Fixup pusher pool notifications --- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 2 +- synapse/push/pusherpool.py | 19 ++++++++++++++++--- synapse/replication/tcp/client.py | 3 ++- tests/handlers/test_typing.py | 1 + 7 files changed, 23 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 43f2986f89..74d7ac8a67 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2970,7 +2970,7 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(max_stream_id) async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a7b4916cd..d1556659e3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1145,7 +1145,7 @@ class EventCreationHandler: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(max_stream_id) def _notify(): try: diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index b7ea4438e0..28bd8ab748 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -91,7 +91,7 @@ class EmailPusher: pass self.timed_call = None - def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + def on_new_notifications(self, max_stream_ordering): if self.max_stream_ordering: self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index f21fa9b659..26706bf3e1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -114,7 +114,7 @@ class HttpPusher: if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + def on_new_notifications(self, max_stream_ordering): self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering or 0 ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3c3262a88c..fa8473bf8d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -64,6 +64,12 @@ class PusherPool: self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() + # Record the last stream ID that we were poked about so we can get + # changes since then. We set this to the current max stream ID on + # startup as every individual pusher will have checked for changes on + # startup. 
+ self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] @@ -178,20 +184,27 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, max_stream_id): if not self.pushers: # nothing to do here. return + if max_stream_id < self._last_room_stream_id_seen: + # Nothing to do + return + + prev_stream_id = self._last_room_stream_id_seen + self._last_room_stream_id_seen = max_stream_id + try: users_affected = await self.store.get_push_action_users_in_range( - min_stream_id, max_stream_id + prev_stream_id, max_stream_id ) for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(min_stream_id, max_stream_id) + p.on_new_notifications(max_stream_id) except Exception: logger.exception("Exception in pusher on_new_notifications") diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d6ecf5b327..ccd3147dfd 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -154,7 +154,8 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_stream_ordering() self.notifier.on_new_room_event(event, token, max_token, extra_users) - await self.pusher_pool.on_new_notifications(token, token) + max_token = self.store.get_room_max_stream_ordering() + await self.pusher_pool.on_new_notifications(max_token) # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ae6bc24f4c..f306a09bfa 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -80,6 +80,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_user_directory_stream_pos", "get_current_state_deltas", "get_device_updates_by_remote", + "get_room_max_stream_ordering", ] ) -- cgit 1.5.1 From dc9dcdbd59d4f839c7a96780f7464460fae27851 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Sep 2020 16:19:22 +0100 Subject: Revert "Fixup pusher pool notifications" This reverts commit e7fd336a53a4ca489cdafc389b494d5477019dc0. 
--- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 2 +- synapse/push/pusherpool.py | 19 +++---------------- synapse/replication/tcp/client.py | 3 +-- tests/handlers/test_typing.py | 1 - 7 files changed, 8 insertions(+), 23 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 74d7ac8a67..43f2986f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2970,7 +2970,7 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - await self.pusher_pool.on_new_notifications(max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d1556659e3..8a7b4916cd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1145,7 +1145,7 @@ class EventCreationHandler: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - await self.pusher_pool.on_new_notifications(max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 28bd8ab748..b7ea4438e0 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -91,7 +91,7 @@ class EmailPusher: pass self.timed_call = None - def on_new_notifications(self, max_stream_ordering): + def on_new_notifications(self, min_stream_ordering, max_stream_ordering): if self.max_stream_ordering: self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 26706bf3e1..f21fa9b659 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -114,7 +114,7 @@ class HttpPusher: if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, max_stream_ordering): + def on_new_notifications(self, min_stream_ordering, max_stream_ordering): self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering or 0 ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index fa8473bf8d..3c3262a88c 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -64,12 +64,6 @@ class PusherPool: self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() - # Record the last stream ID that we were poked about so we can get - # changes since then. We set this to the current max stream ID on - # startup as every individual pusher will have checked for changes on - # startup. - self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() - # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] @@ -184,27 +178,20 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. 
return - if max_stream_id < self._last_room_stream_id_seen: - # Nothing to do - return - - prev_stream_id = self._last_room_stream_id_seen - self._last_room_stream_id_seen = max_stream_id - try: users_affected = await self.store.get_push_action_users_in_range( - prev_stream_id, max_stream_id + min_stream_id, max_stream_id ) for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(max_stream_id) + p.on_new_notifications(min_stream_id, max_stream_id) except Exception: logger.exception("Exception in pusher on_new_notifications") diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index ccd3147dfd..d6ecf5b327 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -154,8 +154,7 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_stream_ordering() self.notifier.on_new_room_event(event, token, max_token, extra_users) - max_token = self.store.get_room_max_stream_ordering() - await self.pusher_pool.on_new_notifications(max_token) + await self.pusher_pool.on_new_notifications(token, token) # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f306a09bfa..ae6bc24f4c 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -80,7 +80,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_user_directory_stream_pos", "get_current_state_deltas", "get_device_updates_by_remote", - "get_room_max_stream_ordering", ] ) -- cgit 1.5.1 From c9dbee50aefc22390f600a0219ca7fa1ae9acd88 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Sep 2020 16:56:08 +0100 Subject: Fixup pusher pool notifications (#8287) `pusher_pool.on_new_notifications` expected a min and max stream ID, however that was not what we were passing in. Instead, let's just pass it the current max stream ID and have it track the last stream ID it got passed. I believe that it mostly worked as we called the function for every event. However, it would break for events that got persisted out of order, i.e, that were persisted but the max stream ID wasn't incremented as not all preceding events had finished persisting, and push for that event would be delayed until another event got pushed to the effected users. --- changelog.d/8287.bugfix | 1 + synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 2 +- synapse/push/pusherpool.py | 19 ++++++++++++++++--- synapse/replication/tcp/client.py | 3 ++- tests/handlers/test_typing.py | 1 + 8 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 changelog.d/8287.bugfix (limited to 'tests') diff --git a/changelog.d/8287.bugfix b/changelog.d/8287.bugfix new file mode 100644 index 0000000000..839781aa07 --- /dev/null +++ b/changelog.d/8287.bugfix @@ -0,0 +1 @@ +Fix edge case where push could get delayed for a user until a later event was pushed. 
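The pusherpool.py hunk below implements the behaviour described in the commit message: remember the highest stream ordering already handled and only look up push actions newer than that. A minimal standalone sketch of the same bookkeeping — illustrative only; `PusherPoolSketch` and its methods are made-up names, not Synapse APIs:

class PusherPoolSketch:
    def __init__(self, current_max_stream_id: int):
        # Start at the current maximum: every pusher checks for missed
        # notifications on startup, so anything older is already covered.
        self._last_stream_id_seen = current_max_stream_id

    def on_new_notifications(self, max_stream_id: int) -> range:
        """Return the range of stream orderings that still need pushing."""
        if max_stream_id <= self._last_stream_id_seen:
            # An out-of-order persist can poke us with an ID we have already
            # covered; there is nothing new to push in that case.
            return range(0)
        prev = self._last_stream_id_seen
        self._last_stream_id_seen = max_stream_id
        return range(prev + 1, max_stream_id + 1)

pool = PusherPoolSketch(current_max_stream_id=10)
assert list(pool.on_new_notifications(12)) == [11, 12]
assert list(pool.on_new_notifications(12)) == []  # duplicate poke: no work to do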
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 43f2986f89..74d7ac8a67 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2970,7 +2970,7 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(max_stream_id) async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a7b4916cd..d1556659e3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1145,7 +1145,7 @@ class EventCreationHandler: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(max_stream_id) def _notify(): try: diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index b7ea4438e0..28bd8ab748 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -91,7 +91,7 @@ class EmailPusher: pass self.timed_call = None - def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + def on_new_notifications(self, max_stream_ordering): if self.max_stream_ordering: self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index f21fa9b659..26706bf3e1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -114,7 +114,7 @@ class HttpPusher: if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, min_stream_ordering, max_stream_ordering): + def on_new_notifications(self, max_stream_ordering): self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering or 0 ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3c3262a88c..fa8473bf8d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -64,6 +64,12 @@ class PusherPool: self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() + # Record the last stream ID that we were poked about so we can get + # changes since then. We set this to the current max stream ID on + # startup as every individual pusher will have checked for changes on + # startup. + self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] @@ -178,20 +184,27 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, max_stream_id): if not self.pushers: # nothing to do here. 
return + if max_stream_id < self._last_room_stream_id_seen: + # Nothing to do + return + + prev_stream_id = self._last_room_stream_id_seen + self._last_room_stream_id_seen = max_stream_id + try: users_affected = await self.store.get_push_action_users_in_range( - min_stream_id, max_stream_id + prev_stream_id, max_stream_id ) for u in users_affected: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(min_stream_id, max_stream_id) + p.on_new_notifications(max_stream_id) except Exception: logger.exception("Exception in pusher on_new_notifications") diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d6ecf5b327..ccd3147dfd 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -154,7 +154,8 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_stream_ordering() self.notifier.on_new_room_event(event, token, max_token, extra_users) - await self.pusher_pool.on_new_notifications(token, token) + max_token = self.store.get_room_max_stream_ordering() + await self.pusher_pool.on_new_notifications(max_token) # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ae6bc24f4c..f306a09bfa 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -80,6 +80,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_user_directory_stream_pos", "get_current_state_deltas", "get_device_updates_by_remote", + "get_room_max_stream_ordering", ] ) -- cgit 1.5.1 From b312769c0ee2c40b1a26a6ed39ea1c8a462d4349 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Sep 2020 12:59:41 -0400 Subject: Do not error when thumbnailing invalid files (#8236) If a file cannot be thumbnailed for some reason (e.g. the file is empty), then catch the exception and convert it to a reasonable error message for the client. --- changelog.d/8236.bugfix | 1 + synapse/rest/media/v1/media_repository.py | 69 +++++++++++++++++++++++++---- synapse/rest/media/v1/thumbnail_resource.py | 5 ++- synapse/rest/media/v1/thumbnailer.py | 14 +++++- tests/rest/media/v1/test_media_storage.py | 39 +++++++++++----- 5 files changed, 106 insertions(+), 22 deletions(-) create mode 100644 changelog.d/8236.bugfix (limited to 'tests') diff --git a/changelog.d/8236.bugfix b/changelog.d/8236.bugfix new file mode 100644 index 0000000000..6f04871015 --- /dev/null +++ b/changelog.d/8236.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug where files that could not be thumbnailed would result in an Internal Server Error. 
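The media_repository.py and thumbnailer.py hunks below catch failures to open the uploaded file and turn them into a clean 400 response instead of an unhandled exception (which the REST layer would report as a 500). A small standalone sketch of that conversion — illustrative only; the `ThumbnailError` class mirrors the one added below, but the helper function is a made-up name:

from io import BytesIO

from PIL import Image

class ThumbnailError(Exception):
    """Raised when a thumbnail cannot be generated from the source media."""

def open_image_or_raise(data: bytes) -> Image.Image:
    try:
        return Image.open(BytesIO(data))
    except OSError as e:
        # Pillow raises an OSError subclass for empty or unreadable files;
        # converting it lets the caller respond with 400 "Failed to generate
        # thumbnail." rather than crashing.
        raise ThumbnailError("Unable to open image") from e

try:
    open_image_or_raise(b"")  # an empty file, as in the new test case
except ThumbnailError as e:
    print("would respond 400: Failed to generate thumbnail.", e)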
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9a1b7779f7..69f353d46f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -53,7 +53,7 @@ from .media_storage import MediaStorage from .preview_url_resource import PreviewUrlResource from .storage_provider import StorageProviderWrapper from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer +from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource logger = logging.getLogger(__name__) @@ -460,13 +460,30 @@ class MediaRepository: return t_byte_source async def generate_local_exact_thumbnail( - self, media_id, t_width, t_height, t_method, t_type, url_cache - ): + self, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + url_cache: str, + ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, @@ -506,14 +523,36 @@ class MediaRepository: return output_path + # Could not generate thumbnail. + return None + async def generate_remote_exact_thumbnail( - self, server_name, file_id, media_id, t_width, t_height, t_method, t_type - ): + self, + server_name: str, + file_id: str, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, @@ -559,6 +598,9 @@ class MediaRepository: return output_path + # Could not generate thumbnail. 
+ return None + async def _generate_thumbnails( self, server_name: Optional[str], @@ -590,7 +632,18 @@ class MediaRepository: FileInfo(server_name, file_id, url_cache=url_cache) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + media_type, + e, + ) + return None + m_width = thumbnailer.width m_height = thumbnailer.height diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index a83535b97b..30421b663a 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,6 +16,7 @@ import logging +from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string @@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource): await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") - respond_404(request) + raise SynapseError(400, "Failed to generate thumbnail.") async def _select_or_generate_remote_thumbnail( self, @@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource): await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") - respond_404(request) + raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index d681bf7bf0..457ad6031c 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -15,7 +15,7 @@ import logging from io import BytesIO -from PIL import Image as Image +from PIL import Image logger = logging.getLogger(__name__) @@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = { } +class ThumbnailError(Exception): + """An error occurred generating a thumbnail.""" + + class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} def __init__(self, input_path): - self.image = Image.open(input_path) + try: + self.image = Image.open(input_path) + except OSError as e: + # If an error occurs opening the image, a thumbnail won't be able to + # be generated. 
+ raise ThumbnailError from e + self.width, self.height = self.image.size self.transpose_method = None try: diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index f4f3e56777..5f897d49cf 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -120,12 +120,13 @@ class _TestImage: extension = attr.ib(type=bytes) expected_cropped = attr.ib(type=Optional[bytes]) expected_scaled = attr.ib(type=Optional[bytes]) + expected_found = attr.ib(default=True, type=bool) @parameterized_class( ("test_image",), [ - # smol png + # smoll png ( _TestImage( unhexlify( @@ -161,6 +162,8 @@ class _TestImage: None, ), ), + # an empty file + (_TestImage(b"", b"image/gif", b".gif", None, None, False,),), ], ) class MediaRepoTests(unittest.HomeserverTestCase): @@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): - self._test_thumbnail("crop", self.test_image.expected_cropped) + self._test_thumbnail( + "crop", self.test_image.expected_cropped, self.test_image.expected_found + ) def test_thumbnail_scale(self): - self._test_thumbnail("scale", self.test_image.expected_scaled) + self._test_thumbnail( + "scale", self.test_image.expected_scaled, self.test_image.expected_found + ) - def _test_thumbnail(self, method, expected_body): + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method request, channel = self.make_request( "GET", self.media_id + params, shorthand=False @@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) self.pump() - self.assertEqual(channel.code, 200) - if expected_body is not None: + if expected_found: + self.assertEqual(channel.code, 200) + if expected_body is not None: + self.assertEqual( + channel.result["body"], expected_body, channel.result["body"] + ) + else: + # ensure that the result is at least some valid image + Image.open(BytesIO(channel.result["body"])) + else: + # A 404 with a JSON body. + self.assertEqual(channel.code, 404) self.assertEqual( - channel.result["body"], expected_body, channel.result["body"] + channel.json_body, + { + "errcode": "M_NOT_FOUND", + "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']" + % method, + }, ) - else: - # ensure that the result is at least some valid image - Image.open(BytesIO(channel.result["body"])) -- cgit 1.5.1 From a3a90ee031d3942c04ab0d985678caf30a94f9e8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 10 Sep 2020 11:45:12 +0100 Subject: Show a confirmation page during user password reset (#8004) This PR adds a confirmation step to resetting your user password between clicking the link in your email and your password actually being reset. This is to better align our password reset flow with the industry standard of requiring a confirmation from the user after email validation. 
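Concretely, after this change the link in the password reset email points at a GET page under /_synapse/client/password_reset/email/submit_token (see the mailer.py and homeserver.py hunks below), and the reset only completes once the page's form POSTs the same sid, token and client_secret back as a URL-encoded body. An illustrative sketch of that final POST using only the standard library — the hostname and parameter values are made-up examples, not part of the patch:

from urllib.parse import urlencode
from urllib.request import Request

params = {"sid": "12345", "token": "abcdef", "client_secret": "monkey"}  # example values
confirm = Request(
    "https://example.com/_synapse/client/password_reset/email/submit_token",
    data=urlencode(params).encode("utf-8"),
    headers={"Content-Type": "application/x-www-form-urlencoded"},
    method="POST",
)
# Passing `confirm` to urllib.request.urlopen would complete the validation;
# a plain GET with the same query string only renders the confirmation page
# and does not change the password.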
--- UPGRADE.rst | 24 ++++ changelog.d/8004.feature | 1 + docs/sample_config.yaml | 10 +- synapse/api/urls.py | 1 + synapse/app/homeserver.py | 10 ++ synapse/config/emailconfig.py | 12 +- synapse/push/mailer.py | 2 +- .../res/templates/password_reset_confirmation.html | 16 +++ synapse/rest/__init__.py | 6 +- synapse/rest/client/v2_alpha/account.py | 76 ------------ synapse/rest/synapse/__init__.py | 14 +++ synapse/rest/synapse/client/__init__.py | 14 +++ synapse/rest/synapse/client/password_reset.py | 127 +++++++++++++++++++++ tests/rest/client/v2_alpha/test_account.py | 29 ++++- tests/server.py | 15 ++- tests/unittest.py | 4 + 16 files changed, 271 insertions(+), 90 deletions(-) create mode 100644 changelog.d/8004.feature create mode 100644 synapse/res/templates/password_reset_confirmation.html create mode 100644 synapse/rest/synapse/__init__.py create mode 100644 synapse/rest/synapse/client/__init__.py create mode 100644 synapse/rest/synapse/client/password_reset.py (limited to 'tests') diff --git a/UPGRADE.rst b/UPGRADE.rst index 77be1b2952..1e4da98afe 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -88,6 +88,30 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.21.0 +==================== + +New HTML templates +------------------ + +A new HTML template, +`password_reset_confirmation.html `_, +has been added to the ``synapse/res/templates`` directory. If you are using a +custom template directory, you may want to copy the template over and modify it. + +Note that as of v1.20.0, templates do not need to be included in custom template +directories for Synapse to start. The default templates will be used if a custom +template cannot be found. + +This page will appear to the user after clicking a password reset link that has +been emailed to them. + +To complete password reset, the page must include a way to make a `POST` +request to +``/_synapse/client/password_reset/{medium}/submit_token`` +with the query parameters from the original link, presented as a URL-encoded form. See the file +itself for more details. + Upgrading to v1.18.0 ==================== diff --git a/changelog.d/8004.feature b/changelog.d/8004.feature new file mode 100644 index 0000000000..a91b75e0e0 --- /dev/null +++ b/changelog.d/8004.feature @@ -0,0 +1 @@ +Require the user to confirm that their password should be reset after clicking the email confirmation link. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 994b0a62c4..2a5b2e0935 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2039,9 +2039,13 @@ email: # * The contents of password reset emails sent by the homeserver: # 'password_reset.html' and 'password_reset.txt' # - # * HTML pages for success and failure that a user will see when they follow - # the link in the password reset email: 'password_reset_success.html' and - # 'password_reset_failure.html' + # * An HTML page that a user will see when they follow the link in the password + # reset email. 
The user will be asked to confirm the action before their + # password is reset: 'password_reset_confirmation.html' + # + # * HTML pages for success and failure that a user will see when they confirm + # the password reset flow using the page above: 'password_reset_success.html' + # and 'password_reset_failure.html' # # * The contents of address verification emails sent during registration: # 'registration.html' and 'registration.txt' diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bbfccf955e..6379c86dde 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -21,6 +21,7 @@ from urllib.parse import urlencode from synapse.config import ConfigError +SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client" CLIENT_API_PREFIX = "/_matrix/client" FEDERATION_PREFIX = "/_matrix/federation" FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6014adc850..b08319ca77 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -48,6 +48,7 @@ from synapse.api.urls import ( from synapse.app import _base from synapse.app._base import listen_ssl, listen_tcp, quit_with_error from synapse.config._base import ConfigError +from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer @@ -209,6 +210,15 @@ class SynapseHomeServer(HomeServer): resources["/_matrix/saml2"] = SAML2Resource(self) + if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL: + from synapse.rest.synapse.client.password_reset import ( + PasswordResetSubmitTokenResource, + ) + + resources[ + "/_synapse/client/password_reset/email/submit_token" + ] = PasswordResetSubmitTokenResource(self) + if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 7a796996c0..72b42bfd62 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -228,6 +228,7 @@ class EmailConfig(Config): self.email_registration_template_text, self.email_add_threepid_template_html, self.email_add_threepid_template_text, + self.email_password_reset_template_confirmation_html, self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, @@ -242,6 +243,7 @@ class EmailConfig(Config): registration_template_text, add_threepid_template_html, add_threepid_template_text, + "password_reset_confirmation.html", password_reset_template_failure_html, registration_template_failure_html, add_threepid_template_failure_html, @@ -404,9 +406,13 @@ class EmailConfig(Config): # * The contents of password reset emails sent by the homeserver: # 'password_reset.html' and 'password_reset.txt' # - # * HTML pages for success and failure that a user will see when they follow - # the link in the password reset email: 'password_reset_success.html' and - # 'password_reset_failure.html' + # * An HTML page that a user will see when they follow the link in the password + # reset email. 
The user will be asked to confirm the action before their + # password is reset: 'password_reset_confirmation.html' + # + # * HTML pages for success and failure that a user will see when they confirm + # the password reset flow using the page above: 'password_reset_success.html' + # and 'password_reset_failure.html' # # * The contents of address verification emails sent during registration: # 'registration.html' and 'registration.txt' diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 6c57854018..455a1acb46 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -123,7 +123,7 @@ class Mailer: params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( self.hs.config.public_baseurl - + "_matrix/client/unstable/password_reset/email/submit_token?%s" + + "_synapse/client/password_reset/email/submit_token?%s" % urllib.parse.urlencode(params) ) diff --git a/synapse/res/templates/password_reset_confirmation.html b/synapse/res/templates/password_reset_confirmation.html new file mode 100644 index 0000000000..def4b5162b --- /dev/null +++ b/synapse/res/templates/password_reset_confirmation.html @@ -0,0 +1,16 @@ + + + + +
+    <!-- These params will be included in the link sent in the email -->
+    <input type="hidden" name="sid" value="{{ sid }}">
+    <input type="hidden" name="token" value="{{ token }}">
+    <input type="hidden" name="client_secret" value="{{ client_secret }}">
+    <p>You have requested to reset your Matrix account password. Click the link below to confirm this action.<br /><br />
+        If you did not mean to do this, please close this page and your password will not be changed.</p>
+    <p><button type="submit">Confirm changing my password</button></p>
+ + + diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 87f927890c..40f5c32db2 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.rest.admin from synapse.http.server import JsonResource +from synapse.rest import admin from synapse.rest.client import versions from synapse.rest.client.v1 import ( directory, @@ -123,9 +123,7 @@ class ClientRestResource(JsonResource): password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin - synapse.rest.admin.register_servlets_for_client_rest_resource( - hs, client_resource - ) + admin.register_servlets_for_client_rest_resource(hs, client_resource) # unstable shared_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 455051ac46..c6cb9deb2b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -152,81 +152,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): return 200, ret -class PasswordResetSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission""" - - PATTERNS = client_patterns( - "/password_reset/(?P[^/]*)/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(PasswordResetSubmitTokenServlet, self).__init__() - self.hs = hs - self.auth = hs.get_auth() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self._failure_email_template = ( - self.config.email_password_reset_template_failure_html - ) - - async def on_GET(self, request, medium): - # We currently only handle threepid token submissions for email - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for password resets" - ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Password reset emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Email-based password resets are disabled on this server" - ) - - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email_password_reset_template_success_html_content - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self._failure_email_template.render(**template_vars) - - 
respond_with_html(request, status_code, html) - - class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") @@ -938,7 +863,6 @@ class WhoamiRestServlet(RestServlet): def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) - PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py new file mode 100644 index 0000000000..c0b733488b --- /dev/null +++ b/synapse/rest/synapse/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py new file mode 100644 index 0000000000..c0b733488b --- /dev/null +++ b/synapse/rest/synapse/client/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py new file mode 100644 index 0000000000..9e4fbc0cbd --- /dev/null +++ b/synapse/rest/synapse/client/password_reset.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request + +from synapse.api.errors import ThreepidValidationError +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http.server import DirectServeHtmlResource +from synapse.http.servlet import parse_string +from synapse.util.stringutils import assert_valid_client_secret + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PasswordResetSubmitTokenResource(DirectServeHtmlResource): + """Handles 3PID validation token submission + + This resource gets mounted under /_synapse/client/password_reset/email/submit_token + """ + + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + """ + Args: + hs: server + """ + super().__init__() + + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + self._local_threepid_handling_disabled_due_to_email_config = ( + hs.config.local_threepid_handling_disabled_due_to_email_config + ) + self._confirmation_email_template = ( + hs.config.email_password_reset_template_confirmation_html + ) + self._email_password_reset_template_success_html = ( + hs.config.email_password_reset_template_success_html_content + ) + self._failure_email_template = ( + hs.config.email_password_reset_template_failure_html + ) + + # This resource should not be mounted if threepid behaviour is not LOCAL + assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL + + async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Show a confirmation page, just in case someone accidentally clicked this link when + # they didn't mean to + template_vars = { + "sid": sid, + "token": token, + "client_secret": client_secret, + } + return ( + 200, + self._confirmation_email_template.render(**template_vars).encode("utf-8"), + ) + + async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]: + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + next_link_bytes = next_link.encode("utf-8") + request.setHeader("Location", next_link_bytes) + return ( + 302, + ( + b'You are being redirected to %s.' 
+ % (next_link_bytes, next_link_bytes) + ), + ) + + # Otherwise show the success template + html_bytes = self._email_password_reset_template_success_html.encode( + "utf-8" + ) + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html_bytes = self._failure_email_template.render(**template_vars).encode( + "utf-8" + ) + + return status_code, html_bytes diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 0a51aeff92..93f899d861 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -19,6 +19,7 @@ import os import re from email.parser import Parser from typing import Optional +from urllib.parse import urlencode import pkg_resources @@ -27,6 +28,7 @@ from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register +from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest from tests.unittest import override_config @@ -70,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): """Test basic password reset flow @@ -251,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Remove the host path = link.replace("https://example.com", "") + # Load the password reset confirmation page request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) + request.render(self.submit_token_resource) + self.pump() + self.assertEquals(200, channel.code, channel.result) + + # Now POST to the same endpoint, mimicking the same behaviour as clicking the + # password reset confirm button + + # Send arguments as url-encoded form data, matching the template's behaviour + form_args = [] + for key, value_list in request.args.items(): + for value in value_list: + arg = (key, value) + form_args.append(arg) + + # Confirm the password reset + request, channel = self.make_request( + "POST", + path, + content=urlencode(form_args).encode("utf8"), + shorthand=False, + content_is_form=True, + ) + request.render(self.submit_token_resource) + self.pump() self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): diff --git a/tests/server.py b/tests/server.py index 48e45c6c8b..61ec670155 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,6 +1,6 @@ import json import logging -from io import BytesIO +from io import SEEK_END, BytesIO import attr from zope.interface import implementer @@ -135,6 +135,7 @@ def make_request( request=SynapseRequest, shorthand=True, federation_auth_origin=None, + content_is_form=False, ): """ Make a web request using the given method and path, feed it the @@ -150,6 +151,8 @@ def make_request( with the usual REST API path, if it doesn't contain it. federation_auth_origin (bytes|None): if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_is_form: Whether the content is URL encoded form data. Adds the + 'Content-Type': 'application/x-www-form-urlencoded' header. 
Returns: Tuple[synapse.http.site.SynapseRequest, channel] @@ -181,6 +184,8 @@ def make_request( req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) + # Twisted expects to be at the end of the content when parsing the request. + req.content.seek(SEEK_END) req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: @@ -195,7 +200,13 @@ def make_request( ) if content: - req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + if content_is_form: + req.requestHeaders.addRawHeader( + b"Content-Type", b"application/x-www-form-urlencoded" + ) + else: + # Assume the body is JSON + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestReceived(method, path, b"1.1") diff --git a/tests/unittest.py b/tests/unittest.py index 3cb55a7e96..128dd4e19c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -353,6 +353,7 @@ class HomeserverTestCase(TestCase): request: Type[T] = SynapseRequest, shorthand: bool = True, federation_auth_origin: str = None, + content_is_form: bool = False, ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the @@ -368,6 +369,8 @@ class HomeserverTestCase(TestCase): with the usual REST API path, if it doesn't contain it. federation_auth_origin (bytes|None): if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_is_form: Whether the content is URL encoded form data. Adds the + 'Content-Type': 'application/x-www-form-urlencoded' header. Returns: Tuple[synapse.http.site.SynapseRequest, channel] @@ -384,6 +387,7 @@ class HomeserverTestCase(TestCase): request, shorthand, federation_auth_origin, + content_is_form, ) def render(self, request): -- cgit 1.5.1 From c312ee3cde39d9c97d3552b43533a4384321dc9e Mon Sep 17 00:00:00 2001 From: Dan Callaghan Date: Fri, 11 Sep 2020 04:49:08 +1000 Subject: Use TLSv1.2 for fake servers in tests (#8208) Some Linux distros have begun disabling TLSv1.0 and TLSv1.1 by default for security reasons, for example in Fedora 33 onwards: https://fedoraproject.org/wiki/Changes/StrongCryptoSettings2 Use TLSv1.2 for the fake TLS servers created in the test suite, to avoid failures due to OpenSSL disallowing TLSv1.0: Signed-off-by: Dan Callaghan --- changelog.d/8208.misc | 1 + tests/http/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8208.misc (limited to 'tests') diff --git a/changelog.d/8208.misc b/changelog.d/8208.misc new file mode 100644 index 0000000000..e65da88c46 --- /dev/null +++ b/changelog.d/8208.misc @@ -0,0 +1 @@ +Fix tests on distros which disable TLSv1.0. Contributed by @danc86. diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 5d41443293..3e5a856584 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory: self._cert_file = create_test_cert_file(sanlist) def serverConnectionForTLS(self, tlsProtocol): - ctx = SSL.Context(SSL.TLSv1_METHOD) + ctx = SSL.Context(SSL.SSLv23_METHOD) ctx.use_certificate_file(self._cert_file) ctx.use_privatekey_file(get_test_key_file()) return Connection(ctx, None) -- cgit 1.5.1 From fe8ed1b46f781faa45d1bba8f9308cf47c42010f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Sep 2020 12:22:55 +0100 Subject: Make `StreamToken.room_key` be a `RoomStreamToken` instance. 
(#8281) --- changelog.d/8281.misc | 1 + mypy.ini | 2 + synapse/handlers/admin.py | 6 +-- synapse/handlers/device.py | 12 +++-- synapse/handlers/initial_sync.py | 4 +- synapse/handlers/message.py | 1 + synapse/handlers/pagination.py | 4 +- synapse/handlers/room.py | 15 +++---- synapse/handlers/sync.py | 11 ++--- synapse/notifier.py | 16 +++++-- synapse/storage/__init__.py | 5 ++- synapse/storage/databases/main/events.py | 21 ++++++--- synapse/storage/databases/main/stream.py | 75 ++++++++++++++------------------ synapse/storage/persist_events.py | 16 ++++--- synapse/types.py | 19 ++++---- tests/test_utils/event_injection.py | 5 ++- 16 files changed, 114 insertions(+), 99 deletions(-) create mode 100644 changelog.d/8281.misc (limited to 'tests') diff --git a/changelog.d/8281.misc b/changelog.d/8281.misc new file mode 100644 index 0000000000..74357120a7 --- /dev/null +++ b/changelog.d/8281.misc @@ -0,0 +1 @@ +Change `StreamToken.room_key` to be a `RoomStreamToken` instance. diff --git a/mypy.ini b/mypy.ini index 460392377e..7986781432 100644 --- a/mypy.ini +++ b/mypy.ini @@ -46,10 +46,12 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/databases/main/events.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, synapse/storage/engines, + synapse/storage/persist_events.py, synapse/storage/state.py, synapse/storage/util, synapse/streams, diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 918d0e037c..5e5a64037d 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -125,8 +125,8 @@ class AdminHandler(BaseHandler): else: stream_ordering = room.stream_ordering - from_key = str(RoomStreamToken(0, 0)) - to_key = str(RoomStreamToken(None, stream_ordering)) + from_key = RoomStreamToken(0, 0) + to_key = RoomStreamToken(None, stream_ordering) written_events = set() # Events that we've processed in this room @@ -153,7 +153,7 @@ class AdminHandler(BaseHandler): if not events: break - from_key = events[-1].internal_metadata.after + from_key = RoomStreamToken.parse(events[-1].internal_metadata.after) events = await filter_events_for_client(self.storage, user_id, events) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 643d71a710..4b0a4f96cc 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( RoomStreamToken, + StreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler): @trace @measure_func("device.get_user_ids_changed") - async def get_user_ids_changed(self, user_id, from_token): + async def get_user_ids_changed(self, user_id: str, from_token: StreamToken): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. 
- - Args: - user_id (str) - from_token (StreamToken) """ set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = await self.store.get_room_events_max_id() + now_room_id = self.store.get_room_max_stream_ordering() + now_room_key = RoomStreamToken(None, now_room_id) room_ids = await self.store.get_rooms_for_user(user_id) @@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler): ) rooms_changed.update(event.room_id for event in member_events) - stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream + stream_ordering = from_token.room_key.stream possibly_changed = set(changed) possibly_left = set() diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index ddb8f0712b..ba4828c713 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler): self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: - room_end_token = "s%d" % (event.stream_ordering,) + room_end_token = RoomStreamToken(None, event.stream_ordering,) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 276de8f8d0..e54e2b322b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -973,6 +973,7 @@ class EventCreationHandler: This should only be run on the instance in charge of persisting events. """ assert self._is_event_writer + assert self.storage.persistence is not None if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index ec17d3d888..d929a68f7d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -344,7 +344,7 @@ class PaginationHandler: # gets called. raise Exception("limit not set") - room_token = RoomStreamToken.parse(from_token.room_key) + room_token = from_token.room_key with await self.pagination_lock.read(room_id): ( @@ -381,7 +381,7 @@ class PaginationHandler: if leave_token.topological < max_topo: from_token = from_token.copy_and_replace( - "room_key", leave_token_str + "room_key", leave_token ) await self.hs.get_handlers().federation_handler.maybe_backfill( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a29305f655..53d85ab97d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1091,20 +1091,19 @@ class RoomEventSource: async def get_new_events( self, user: UserID, - from_key: str, + from_key: RoomStreamToken, limit: int, room_ids: List[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[EventBase], str]: + ) -> Tuple[List[EventBase], RoomStreamToken]: # We just ignore the key for now. 
to_key = self.get_current_key() - from_token = RoomStreamToken.parse(from_key) - if from_token.topological: + if from_key.topological: logger.warning("Stream has topological part!!!! %r", from_key) - from_key = "s%s" % (from_token.stream,) + from_key = RoomStreamToken(None, from_key.stream) app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: @@ -1133,14 +1132,14 @@ class RoomEventSource: events[:] = events[:limit] if events: - end_key = events[-1].internal_metadata.after + end_key = RoomStreamToken.parse(events[-1].internal_metadata.after) else: end_key = to_key return (events, end_key) - def get_current_key(self) -> str: - return "s%d" % (self.store.get_room_max_stream_ordering(),) + def get_current_key(self) -> RoomStreamToken: + return RoomStreamToken(None, self.store.get_room_max_stream_ordering()) def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index cc47e8b62c..a615c7c2f0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -378,7 +378,7 @@ class SyncHandler: sync_config = sync_result_builder.sync_config with Measure(self.clock, "ephemeral_by_room"): - typing_key = since_token.typing_key if since_token else "0" + typing_key = since_token.typing_key if since_token else 0 room_ids = sync_result_builder.joined_room_ids @@ -402,7 +402,7 @@ class SyncHandler: event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) - receipt_key = since_token.receipt_key if since_token else "0" + receipt_key = since_token.receipt_key if since_token else 0 receipt_source = self.event_sources.sources["receipt"] receipts, receipt_key = await receipt_source.get_new_events( @@ -533,7 +533,7 @@ class SyncHandler: if len(recents) > timeline_limit: limited = True recents = recents[-timeline_limit:] - room_key = recents[0].internal_metadata.before + room_key = RoomStreamToken.parse(recents[0].internal_metadata.before) prev_batch_token = now_token.copy_and_replace("room_key", room_key) @@ -1322,6 +1322,7 @@ class SyncHandler: is_guest=sync_config.is_guest, include_offline=include_offline, ) + assert presence_key sync_result_builder.now_token = now_token.copy_and_replace( "presence_key", presence_key ) @@ -1484,7 +1485,7 @@ class SyncHandler: if rooms_changed: return True - stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream + stream_id = since_token.room_key.stream for room_id in sync_result_builder.joined_room_ids: if self.store.has_room_changed_since(room_id, stream_id): return True @@ -1750,7 +1751,7 @@ class SyncHandler: continue leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) + "room_key", RoomStreamToken(None, event.stream_ordering) ) room_entries.append( RoomSyncResultBuilder( diff --git a/synapse/notifier.py b/synapse/notifier.py index 16f19c938e..12cd84b27b 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -25,6 +25,7 @@ from typing import ( Set, Tuple, TypeVar, + Union, ) from prometheus_client import Counter @@ -41,7 +42,7 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig -from synapse.types import Collection, StreamToken, UserID +from synapse.types import Collection, RoomStreamToken, StreamToken, UserID 
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -111,7 +112,9 @@ class _NotifierUserStream: with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - def notify(self, stream_key: str, stream_id: int, time_now_ms: int): + def notify( + self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, + ): """Notify any listeners for this user of a new event from an event source. Args: @@ -294,7 +297,12 @@ class Notifier: rooms.add(event.room_id) if users or rooms: - self.on_new_event("room_key", max_room_stream_id, users=users, rooms=rooms) + self.on_new_event( + "room_key", + RoomStreamToken(None, max_room_stream_id), + users=users, + rooms=rooms, + ) self._on_updated_room_token(max_room_stream_id) def _on_updated_room_token(self, max_room_stream_id: int): @@ -329,7 +337,7 @@ class Notifier: def on_new_event( self, stream_key: str, - new_token: int, + new_token: Union[int, RoomStreamToken], users: Collection[UserID] = [], rooms: Collection[str] = [], ): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 8e5d78f6f7..bbff3c8d5b 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -47,6 +47,9 @@ class Storage: # interfaces. self.main = stores.main - self.persistence = EventsPersistenceStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores) self.state = StateGroupStorage(hs, stores) + + self.persistence = None + if stores.persist_events: + self.persistence = EventsPersistenceStorage(hs, stores) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b3d27a2ee7..9cd1403b38 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -213,7 +213,7 @@ class PersistEventsStore: Returns: Filtered event ids """ - results = [] + results = [] # type: List[str] def _get_events_which_are_prevs_txn(txn, batch): sql = """ @@ -631,7 +631,9 @@ class PersistEventsStore: ) @classmethod - def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + def _filter_events_and_contexts_for_duplicates( + cls, events_and_contexts: List[Tuple[EventBase, EventContext]] + ) -> List[Tuple[EventBase, EventContext]]: """Ensure that we don't have the same event twice. Pick the earliest non-outlier if there is one, else the earliest one. 
@@ -641,7 +643,9 @@ class PersistEventsStore: Returns: list[(EventBase, EventContext)]: filtered list """ - new_events_and_contexts = OrderedDict() + new_events_and_contexts = ( + OrderedDict() + ) # type: OrderedDict[str, Tuple[EventBase, EventContext]] for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) if prev_event_context: @@ -655,7 +659,12 @@ class PersistEventsStore: new_events_and_contexts[event.event_id] = (event, context) return list(new_events_and_contexts.values()) - def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + def _update_room_depths_txn( + self, + txn, + events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + ): """Update min_depth for each room Args: @@ -664,7 +673,7 @@ class PersistEventsStore: we are persisting backfilled (bool): True if the events were backfilled """ - depth_updates = {} + depth_updates = {} # type: Dict[str, int] for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) @@ -1436,7 +1445,7 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 08a13a8b47..2e95518752 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], - from_key: str, - to_key: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, limit: int = 0, order: str = "DESC", - ) -> Dict[str, Tuple[List[EventBase], str]]: + ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]: """Get new room events in stream ordering since `from_key`. Args: @@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): - list of recent events in the room - stream ordering key for the start of the chunk of events returned. """ - from_id = RoomStreamToken.parse_stream_token(from_key).stream - - room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) + room_ids = self._events_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) if not room_ids: return {} @@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return results def get_rooms_that_changed( - self, room_ids: Collection[str], from_key: str + self, room_ids: Collection[str], from_key: RoomStreamToken ) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. - - Args: - room_ids - from_key: The room_key portion of a StreamToken """ - from_id = RoomStreamToken.parse_stream_token(from_key).stream + from_id = from_key.stream return { room_id for room_id in room_ids @@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_room_events_stream_for_room( self, room_id: str, - from_key: str, - to_key: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, limit: int = 0, order: str = "DESC", - ) -> Tuple[List[EventBase], str]: + ) -> Tuple[List[EventBase], RoomStreamToken]: """Get new room events in stream ordering since `from_key`. 
Args: @@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if from_key == to_key: return [], from_key - from_id = RoomStreamToken.parse_stream_token(from_key).stream - to_id = RoomStreamToken.parse_stream_token(to_key).stream + from_id = from_key.stream + to_id = to_key.stream has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) @@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ret.reverse() if rows: - key = "s%d" % min(r.stream_ordering for r in rows) + key = RoomStreamToken(None, min(r.stream_ordering for r in rows)) else: # Assume we didn't get anything because there was nothing to # get. @@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key async def get_membership_changes_for_user( - self, user_id: str, from_key: str, to_key: str + self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = RoomStreamToken.parse_stream_token(from_key).stream - to_id = RoomStreamToken.parse_stream_token(to_key).stream + from_id = from_key.stream + to_id = to_key.stream if from_key == to_key: return [] @@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret async def get_recent_events_for_room( - self, room_id: str, limit: int, end_token: str - ) -> Tuple[List[EventBase], str]: + self, room_id: str, limit: int, end_token: RoomStreamToken + ) -> Tuple[List[EventBase], RoomStreamToken]: """Get the most recent events in the room in topological ordering. Args: @@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) async def get_recent_event_ids_for_room( - self, room_id: str, limit: int, end_token: str - ) -> Tuple[List[_EventDictReturn], str]: + self, room_id: str, limit: int, end_token: RoomStreamToken + ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: """Get the most recent events in the room in topological ordering. Args: @@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if limit == 0: return [], end_token - parsed_end_token = RoomStreamToken.parse(end_token) - rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, - from_token=parsed_end_token, + from_token=end_token, limit=limit, ) @@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): allow_none=allow_none, ) - async def get_stream_token_for_event(self, event_id: str) -> str: + async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A "s%d" stream token. + A stream token. """ stream_id = await self.get_stream_id_for_event(event_id) - return "s%d" % (stream_id,) + return RoomStreamToken(None, stream_id) async def get_topological_token_for_event(self, event_id: str) -> str: """The stream token for an event @@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): direction: str = "b", limit: int = -1, event_filter: Optional[Filter] = None, - ) -> Tuple[List[_EventDictReturn], str]: + ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: """Returns list of events before or after a given token. Args: @@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # TODO (erikj): We should work out what to do here instead. 
next_token = to_token if to_token else from_token - return rows, str(next_token) + return rows, next_token async def paginate_room_events( self, room_id: str, - from_key: str, - to_key: Optional[str] = None, + from_key: RoomStreamToken, + to_key: Optional[RoomStreamToken] = None, direction: str = "b", limit: int = -1, event_filter: Optional[Filter] = None, - ) -> Tuple[List[EventBase], str]: + ) -> Tuple[List[EventBase], RoomStreamToken]: """Returns list of events before or after a given token. Args: @@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and `to_key`). """ - parsed_from_key = RoomStreamToken.parse(from_key) - parsed_to_key = None - if to_key: - parsed_to_key = RoomStreamToken.parse(to_key) - rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, - parsed_from_key, - parsed_to_key, + from_key, + to_key, direction, limit, event_filter, diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index dbaeef91dd..d89f6ed128 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -18,7 +18,7 @@ import itertools import logging from collections import deque, namedtuple -from typing import Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter, Histogram @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import StateMap +from synapse.types import Collection, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -185,6 +185,8 @@ class EventsPersistenceStorage: # store for now. self.main_store = stores.main self.state_store = stores.state + + assert stores.persist_events self.persist_events_store = stores.persist_events self._clock = hs.get_clock() @@ -208,7 +210,7 @@ class EventsPersistenceStorage: Returns: the stream ordering of the latest persisted event """ - partitioned = {} + partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) @@ -305,7 +307,9 @@ class EventsPersistenceStorage: # Work out the new "current state" for each room. # We do this by working out what the new extremities are and then # calculating the state from that. - events_by_room = {} + events_by_room = ( + {} + ) # type: Dict[str, List[Tuple[EventBase, EventContext]]] for event, context in chunk: events_by_room.setdefault(event.room_id, []).append( (event, context) @@ -436,7 +440,7 @@ class EventsPersistenceStorage: self, room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], - latest_event_ids: List[str], + latest_event_ids: Collection[str], ): """Calculates the new forward extremities for a room given events to persist. @@ -470,7 +474,7 @@ class EventsPersistenceStorage: # Remove any events which are prev_events of any existing events. 
existing_prevs = await self.persist_events_store._get_events_which_are_prevs( result - ) + ) # type: Collection[str] result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev diff --git a/synapse/types.py b/synapse/types.py index ba45335038..dc09448bdc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -425,7 +425,9 @@ class RoomStreamToken: @attr.s(slots=True, frozen=True) class StreamToken: - room_key = attr.ib(type=str) + room_key = attr.ib( + type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken) + ) presence_key = attr.ib(type=int) typing_key = attr.ib(type=int) receipt_key = attr.ib(type=int) @@ -445,21 +447,16 @@ class StreamToken: while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") - return cls(keys[0], *(int(k) for k in keys[1:])) + return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:])) except Exception: raise SynapseError(400, "Invalid Token") def to_string(self): - return self._SEPARATOR.join([str(k) for k in attr.astuple(self)]) + return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)]) @property def room_stream_id(self): - # TODO(markjh): Awful hack to work around hacks in the presence tests - # which assume that the keys are integers. - if type(self.room_key) is int: - return self.room_key - else: - return int(self.room_key[1:].split("-")[-1]) + return self.room_key.stream def is_after(self, other): """Does this token contain events that the other doesn't?""" @@ -475,7 +472,7 @@ class StreamToken: or (int(other.groups_key) < int(self.groups_key)) ) - def copy_and_advance(self, key, new_value): + def copy_and_advance(self, key, new_value) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. """ @@ -491,7 +488,7 @@ class StreamToken: else: return self - def copy_and_replace(self, key, new_value): + def copy_and_replace(self, key, new_value) -> "StreamToken": return attr.evolve(self, **{key: new_value}) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index fb1ca90336..e93aa84405 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -71,7 +71,10 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - await hs.get_storage().persistence.persist_event(event, context) + persistence = hs.get_storage().persistence + assert persistence is not None + + await persistence.persist_event(event, context) return event -- cgit 1.5.1 From b82d68c0bd952131836d00994c3c2a79b3d3a267 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 14 Sep 2020 17:07:04 +0300 Subject: Add the topic and avatar to the room details admin API (#8305) --- changelog.d/8305.feature | 1 + docs/admin_api/rooms.md | 4 ++++ synapse/storage/databases/main/room.py | 3 ++- tests/rest/admin/test_room.py | 2 ++ 4 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8305.feature (limited to 'tests') diff --git a/changelog.d/8305.feature b/changelog.d/8305.feature new file mode 100644 index 0000000000..862dfdf959 --- /dev/null +++ b/changelog.d/8305.feature @@ -0,0 +1 @@ +Add the room topic and avatar to the room details admin API. 
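As a quick illustration of how a caller might read the two new fields, here is a minimal sketch (not part of the patch itself). It assumes the Room Details admin endpoint documented in docs/admin_api/rooms.md below, GET /_synapse/admin/v1/rooms/<room_id>, plus a valid admin access token; the homeserver URL and token are placeholders, while the example room ID and field values are taken from the documentation sample that follows:

    import requests

    # Hypothetical values - substitute a real homeserver URL, room ID and admin token.
    homeserver = "https://matrix.example.com"
    room_id = "!mscvqgqpHYjBGDxNym:matrix.org"
    admin_token = "ADMIN_ACCESS_TOKEN"

    # Fetch the room details. The endpoint path is an assumption based on the
    # existing Room Details API documentation; this patch only extends its response.
    resp = requests.get(
        "%s/_synapse/admin/v1/rooms/%s" % (homeserver, room_id),
        headers={"Authorization": "Bearer %s" % admin_token},
    )
    resp.raise_for_status()
    details = resp.json()

    # The two fields added by this change, alongside the existing ones:
    print(details["topic"])   # e.g. "Theory, Composition, Notation, Analysis"
    print(details["avatar"])  # e.g. "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV"
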
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 0f267d2b7b..fa9b914fa7 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -275,6 +275,8 @@ The following fields are possible in the JSON response body: * `room_id` - The ID of the room. * `name` - The name of the room. +* `topic` - The topic of the room. +* `avatar` - The `mxc` URI to the avatar of the room. * `canonical_alias` - The canonical (main) alias address of the room. * `joined_members` - How many users are currently in the room. * `joined_local_members` - How many local users are currently in the room. @@ -304,6 +306,8 @@ Response: { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", + "avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV", + "topic": "Theory, Composition, Notation, Analysis", "canonical_alias": "#musictheory:matrix.org", "joined_members": 127 "joined_local_members": 2, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 717df97301..127588ce4c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore): curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, rooms.creator, state.encryption, state.is_federatable AS federatable, rooms.is_public AS public, state.join_rules, state.guest_access, - state.history_visibility, curr.current_state_events AS state_events + state.history_visibility, curr.current_state_events AS state_events, + state.avatar, state.topic FROM rooms LEFT JOIN room_stats_state state USING (room_id) LEFT JOIN room_stats_current curr USING (room_id) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 408c568a27..6dfc709dc5 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1174,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) + self.assertIn("topic", channel.json_body) + self.assertIn("avatar", channel.json_body) self.assertIn("canonical_alias", channel.json_body) self.assertIn("joined_members", channel.json_body) self.assertIn("joined_local_members", channel.json_body) -- cgit 1.5.1 From 576bc37d318f866f11f71e34ce7190aa45b74780 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 15 Sep 2020 09:07:19 +0100 Subject: Catch-up after Federation Outage (split, 4): catch-up loop (#8272) --- changelog.d/8272.bugfix | 1 + synapse/federation/sender/per_destination_queue.py | 129 +++++++++++++++- synapse/storage/databases/main/transactions.py | 43 +++++- tests/federation/test_federation_catch_up.py | 165 +++++++++++++++++++++ tests/handlers/test_typing.py | 5 + 5 files changed, 338 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8272.bugfix (limited to 'tests') diff --git a/changelog.d/8272.bugfix b/changelog.d/8272.bugfix new file mode 100644 index 0000000000..532d0e22fe --- /dev/null +++ b/changelog.d/8272.bugfix @@ -0,0 +1 @@ +Fix messages over federation being lost until an event is sent into the same room. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 9f0852b4a2..2657767fd1 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,7 +15,7 @@ # limitations under the License. 
import datetime import logging -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast from prometheus_client import Counter @@ -92,6 +92,21 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False + # True whilst we are sending events that the remote homeserver missed + # because it was unreachable. We start in this state so we can perform + # catch-up at startup. + # New events will only be sent once this is finished, at which point + # _catching_up is flipped to False. + self._catching_up = True # type: bool + + # The stream_ordering of the most recent PDU that was discarded due to + # being in catch-up mode. + self._catchup_last_skipped = 0 # type: int + + # Cache of the last successfully-transmitted stream ordering for this + # destination (we are the only updater so this is safe) + self._last_successful_stream_ordering = None # type: Optional[int] + # a list of pending PDUs self._pending_pdus = [] # type: List[EventBase] @@ -138,7 +153,13 @@ class PerDestinationQueue: Args: pdu: pdu to send """ - self._pending_pdus.append(pdu) + if not self._catching_up or self._last_successful_stream_ordering is None: + # only enqueue the PDU if we are not catching up (False) or do not + # yet know if we have anything to catch up (None) + self._pending_pdus.append(pdu) + else: + self._catchup_last_skipped = pdu.internal_metadata.stream_ordering + self.attempt_new_transaction() def send_presence(self, states: Iterable[UserPresenceState]) -> None: @@ -218,6 +239,13 @@ class PerDestinationQueue: # hence why we throw the result away. await get_retry_limiter(self._destination, self._clock, self._store) + if self._catching_up: + # we potentially need to catch-up first + await self._catch_up_transmission_loop() + if self._catching_up: + # not caught up yet + return + pending_pdus = [] while True: # We have to keep 2 free slots for presence and rr_edus @@ -351,8 +379,9 @@ class PerDestinationQueue: if e.retry_interval > 60 * 60 * 1000: # we won't retry for another hour! # (this suggests a significant outage) - # We drop pending PDUs and EDUs because otherwise they will + # We drop pending EDUs because otherwise they will # rack up indefinitely. + # (Dropping PDUs is already performed by `_start_catching_up`.) # Note that: # - the EDUs that are being dropped here are those that we can # afford to drop (specifically, only typing notifications, @@ -364,11 +393,12 @@ class PerDestinationQueue: # dropping read receipts is a bit sad but should be solved # through another mechanism, because this is all volatile! 
- self._pending_pdus = [] self._pending_edus = [] self._pending_edus_keyed = {} self._pending_presence = {} self._pending_rrs = {} + + self._start_catching_up() except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -378,6 +408,8 @@ class PerDestinationQueue: e.code, e, ) + + self._start_catching_up() except RequestSendFailed as e: logger.warning( "TX [%s] Failed to send transaction: %s", self._destination, e @@ -387,16 +419,96 @@ class PerDestinationQueue: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) + + self._start_catching_up() except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) + + self._start_catching_up() finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False + async def _catch_up_transmission_loop(self) -> None: + first_catch_up_check = self._last_successful_stream_ordering is None + + if first_catch_up_check: + # first catchup so get last_successful_stream_ordering from database + self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering( + self._destination + ) + + if self._last_successful_stream_ordering is None: + # if it's still None, then this means we don't have the information + # in our database ­ we haven't successfully sent a PDU to this server + # (at least since the introduction of the feature tracking + # last_successful_stream_ordering). + # Sadly, this means we can't do anything here as we don't know what + # needs catching up — so catching up is futile; let's stop. + self._catching_up = False + return + + # get at most 50 catchup room/PDUs + while True: + event_ids = await self._store.get_catch_up_room_event_ids( + self._destination, self._last_successful_stream_ordering, + ) + + if not event_ids: + # No more events to catch up on, but we can't ignore the chance + # of a race condition, so we check that no new events have been + # skipped due to us being in catch-up mode + + if self._catchup_last_skipped > self._last_successful_stream_ordering: + # another event has been skipped because we were in catch-up mode + continue + + # we are done catching up! + self._catching_up = False + break + + if first_catch_up_check: + # as this is our check for needing catch-up, we may have PDUs in + # the queue from before we *knew* we had to do catch-up, so + # clear those out now. + self._start_catching_up() + + # fetch the relevant events from the event store + # - redacted behaviour of REDACT is fine, since we only send metadata + # of redacted events to the destination. + # - don't need to worry about rejected events as we do not actively + # forward received events over federation. + catchup_pdus = await self._store.get_events_as_list(event_ids) + if not catchup_pdus: + raise AssertionError( + "No events retrieved when we asked for %r. " + "This should not happen." 
% event_ids + ) + + if logger.isEnabledFor(logging.INFO): + rooms = (p.room_id for p in catchup_pdus) + logger.info("Catching up rooms to %s: %r", self._destination, rooms) + + success = await self._transaction_manager.send_new_transaction( + self._destination, catchup_pdus, [] + ) + + if not success: + return + + sent_transactions_counter.inc() + final_pdu = catchup_pdus[-1] + self._last_successful_stream_ordering = cast( + int, final_pdu.internal_metadata.stream_ordering + ) + await self._store.set_destination_last_successful_stream_ordering( + self._destination, self._last_successful_stream_ordering + ) + def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: return @@ -457,3 +569,12 @@ class PerDestinationQueue: ] return (edus, stream_id) + + def _start_catching_up(self) -> None: + """ + Marks this destination as being in catch-up mode. + + This throws away the PDU queue. + """ + self._catching_up = True + self._pending_pdus = [] diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index c0a958252e..091367006e 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -371,3 +371,44 @@ class TransactionStore(SQLBaseStore): values={"last_successful_stream_ordering": last_successful_stream_ordering}, desc="set_last_successful_stream_ordering", ) + + async def get_catch_up_room_event_ids( + self, destination: str, last_successful_stream_ordering: int, + ) -> List[str]: + """ + Returns at most 50 event IDs and their corresponding stream_orderings + that correspond to the oldest events that have not yet been sent to + the destination. + + Args: + destination: the destination in question + last_successful_stream_ordering: the stream_ordering of the + most-recently successfully-transmitted event to the destination + + Returns: + list of event_ids + """ + return await self.db_pool.runInteraction( + "get_catch_up_room_event_ids", + self._get_catch_up_room_event_ids_txn, + destination, + last_successful_stream_ordering, + ) + + @staticmethod + def _get_catch_up_room_event_ids_txn( + txn, destination: str, last_successful_stream_ordering: int, + ) -> List[str]: + q = """ + SELECT event_id FROM destination_rooms + JOIN events USING (stream_ordering) + WHERE destination = ? + AND stream_ordering > ? 
+ ORDER BY stream_ordering + LIMIT 50 + """ + txn.execute( + q, (destination, last_successful_stream_ordering), + ) + event_ids = [row[0] for row in txn] + return event_ids diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 6cdcc378f0..cc52c3dfac 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,5 +1,10 @@ +from typing import List, Tuple + from mock import Mock +from synapse.events import EventBase +from synapse.federation.sender import PerDestinationQueue, TransactionManager +from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -156,3 +161,163 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): row_2["stream_ordering"], "Send succeeded but not marked as last_successful_stream_ordering", ) + + @override_config({"send_federation": True}) # critical to federate + def test_catch_up_from_blank_state(self): + """ + Runs an overall test of federation catch-up from scratch. + Further tests will focus on more narrow aspects and edge-cases, but I + hope to provide an overall view with this test. + """ + # bring the other server online + self.is_online = True + + # let's make some events for the other server to receive + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + room_2 = self.helper.create_room_as("u1", tok=u1_token) + + # also critical to federate + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + self.get_success( + event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join") + ) + + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "wombat"}, tok=u1_token + ) + + # check: PDU received for topic event + self.assertEqual(len(self.pdus), 1) + self.assertEqual(self.pdus[0]["type"], "m.room.topic") + + # take the remote offline + self.is_online = False + + # send another event + self.helper.send(room_1, "hi user!", tok=u1_token) + + # check: things didn't go well since the remote is down + self.assertEqual(len(self.failed_pdus), 1) + self.assertEqual(self.failed_pdus[0]["content"]["body"], "hi user!") + + # let's delete the federation transmission queue + # (this pretends we are starting up fresh.) 
+ self.assertFalse( + self.hs.get_federation_sender() + ._per_destination_queues["host2"] + .transmission_loop_running + ) + del self.hs.get_federation_sender()._per_destination_queues["host2"] + + # let's also clear any backoffs + self.get_success( + self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + ) + + # bring the remote online and clear the received pdu list + self.is_online = True + self.pdus = [] + + # now we need to initiate a federation transaction somehow… + # to do that, let's send another event (because it's simple to do) + # (do it to another room otherwise the catch-up logic decides it doesn't + # need to catch up room_1 — something I overlooked when first writing + # this test) + self.helper.send(room_2, "wombats!", tok=u1_token) + + # we should now have received both PDUs + self.assertEqual(len(self.pdus), 2) + self.assertEqual(self.pdus[0]["content"]["body"], "hi user!") + self.assertEqual(self.pdus[1]["content"]["body"], "wombats!") + + def make_fake_destination_queue( + self, destination: str = "host2" + ) -> Tuple[PerDestinationQueue, List[EventBase]]: + """ + Makes a fake per-destination queue. + """ + transaction_manager = TransactionManager(self.hs) + per_dest_queue = PerDestinationQueue(self.hs, transaction_manager, destination) + results_list = [] + + async def fake_send( + destination_tm: str, + pending_pdus: List[EventBase], + _pending_edus: List[Edu], + ) -> bool: + assert destination == destination_tm + results_list.extend(pending_pdus) + return True # success! + + transaction_manager.send_new_transaction = fake_send + + return per_dest_queue, results_list + + @override_config({"send_federation": True}) + def test_catch_up_loop(self): + """ + Tests the behaviour of _catch_up_transmission_loop. 
+ """ + + # ARRANGE: + # - a local user (u1) + # - 3 rooms which u1 is joined to (and remote user @user:host2 is + # joined to) + # - some events (1 to 5) in those rooms + # we have 'already sent' events 1 and 2 to host2 + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + room_2 = self.helper.create_room_as("u1", tok=u1_token) + room_3 = self.helper.create_room_as("u1", tok=u1_token) + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + self.get_success( + event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join") + ) + self.get_success( + event_injection.inject_member_event(self.hs, room_3, "@user:host2", "join") + ) + + # create some events + self.helper.send(room_1, "you hear me!!", tok=u1_token) + event_id_2 = self.helper.send(room_2, "wombats!", tok=u1_token)["event_id"] + self.helper.send(room_3, "Matrix!", tok=u1_token) + event_id_4 = self.helper.send(room_2, "rabbits!", tok=u1_token)["event_id"] + event_id_5 = self.helper.send(room_3, "Synapse!", tok=u1_token)["event_id"] + + # destination_rooms should already be populated, but let us pretend that we already + # sent (successfully) up to and including event id 2 + event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + + # also fetch event 5 so we know its last_successful_stream_ordering later + event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + + self.get_success( + self.hs.get_datastore().set_destination_last_successful_stream_ordering( + "host2", event_2.internal_metadata.stream_ordering + ) + ) + + # ACT + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # ASSERT, noticing in particular: + # - event 3 not sent out, because event 5 replaces it + # - order is least recent first, so event 5 comes after event 4 + # - catch-up is completed + self.assertEqual(len(sent_pdus), 2) + self.assertEqual(sent_pdus[0].event_id, event_id_4) + self.assertEqual(sent_pdus[1].event_id, event_id_5) + self.assertFalse(per_dest_queue._catching_up) + self.assertEqual( + per_dest_queue._last_successful_stream_ordering, + event_5.internal_metadata.stream_ordering, + ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f306a09bfa..3fec09ea8a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -73,6 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "delivered_txn", "get_received_txn_response", "set_received_txn_response", + "get_destination_last_successful_stream_ordering", "get_destination_retry_timings", "get_devices_by_remote", "maybe_store_room_on_invite", @@ -121,6 +122,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): (0, []) ) + self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable( + None + ) + def get_received_txn_response(*args): return defer.succeed(None) -- cgit 1.5.1 From 7c407efdc80abf2a991844d107a896d629e3965a Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Fri, 18 Sep 2020 13:56:40 +0200 Subject: Update test logging to be able to accept braces (#8335) --- changelog.d/8335.misc | 1 + tests/test_utils/logging_setup.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8335.misc (limited to 'tests') diff --git a/changelog.d/8335.misc b/changelog.d/8335.misc new file mode 
100644 index 0000000000..7e0a4c7d83 --- /dev/null +++ b/changelog.d/8335.misc @@ -0,0 +1 @@ +Fix test logging to allow braces in log output. \ No newline at end of file diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 2d96b0fa8d..fdfb840b62 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -29,8 +29,7 @@ class ToTwistedHandler(logging.Handler): log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( - twisted.logger.LogLevel.levelWithName(log_level), - log_entry.replace("{", r"(").replace("}", r")"), + twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry ) -- cgit 1.5.1 From 68c7a6936f8921744d083e6dc8a2a085cce30b2a Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 18 Sep 2020 14:55:13 +0100 Subject: Allow appservice users to /login (#8320) Add ability for ASes to /login using the `uk.half-shot.msc2778.login.application_service` login `type`. Co-authored-by: Patrick Cloke --- changelog.d/8320.feature | 1 + synapse/rest/client/v1/login.py | 49 +++++++++++--- tests/rest/client/v1/test_login.py | 134 ++++++++++++++++++++++++++++++++++++- 3 files changed, 173 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8320.feature (limited to 'tests') diff --git a/changelog.d/8320.feature b/changelog.d/8320.feature new file mode 100644 index 0000000000..475a5fe62d --- /dev/null +++ b/changelog.d/8320.feature @@ -0,0 +1 @@ +Add `uk.half-shot.msc2778.login.application_service` login type to allow appservices to login. diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a14618ac84..dd8cdc0d9f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.appservice import ApplicationService from synapse.handlers.auth import ( convert_client_dict_legacy_fields_to_identifier, login_id_phone_to_thirdparty, @@ -44,6 +45,7 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" + APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" def __init__(self, hs): super(LoginRestServlet, self).__init__() @@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth = hs.get_auth() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -107,6 +111,8 @@ class LoginRestServlet(RestServlet): ({"type": t} for t in self.auth_handler.get_supported_login_types()) ) + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + return 200, {"flows": flows} def on_OPTIONS(self, request: SynapseRequest): @@ -116,8 +122,12 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) + try: - if self.jwt_enabled and ( + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + appservice = self.auth.get_appservice_by_req(request) + result = await self._do_appservice_login(login_submission, appservice) + elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == 
LoginRestServlet.JWT_TYPE_DEPRECATED ): @@ -134,6 +144,33 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result + def _get_qualified_user_id(self, identifier): + if identifier["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + if "user" not in identifier: + raise SynapseError(400, "User identifier is missing 'user' key") + + if identifier["user"].startswith("@"): + return identifier["user"] + else: + return UserID(identifier["user"], self.hs.hostname).to_string() + + async def _do_appservice_login( + self, login_submission: JsonDict, appservice: ApplicationService + ): + logger.info( + "Got appservice login request with identifier: %r", + login_submission.get("identifier"), + ) + + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) + qualified_user_id = self._get_qualified_user_id(identifier) + + if not appservice.is_interested_in_user(qualified_user_id): + raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) + + return await self._complete_login(qualified_user_id, login_submission) + async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -219,15 +256,7 @@ class LoginRestServlet(RestServlet): # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - qualified_user_id = identifier["user"] - else: - qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() + qualified_user_id = self._get_qualified_user_id(identifier) # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2668662c9e..5d987a30c7 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -7,8 +7,9 @@ from mock import Mock import jwt import synapse.rest.admin +from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import devices +from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from tests import unittest @@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature verification failed", ) + + +AS_USER = "as_user_alice" + + +class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + ] + + def register_as_user(self, username): + request, channel = self.make_request( + b"POST", + "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), + {"username": username}, + ) + self.render(request) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + + self.service = ApplicationService( + id="unique_identifier", + token="some_token", + hostname="example.com", + sender="@asbot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + self.another_service = ApplicationService( + 
id="another__identifier", + token="another_token", + hostname="example.com", + sender="@as2bot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as2_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + + self.hs.get_datastore().services_cache.append(self.service) + self.hs.get_datastore().services_cache.append(self.another_service) + return self.hs + + def test_login_appservice_user(self): + """Test that an appservice user can use /login + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_user_bot(self): + """Test that the appservice bot can use /login + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": self.service.sender}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_wrong_user(self): + """Test that non-as users cannot login with the as token + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_wrong_as(self): + """Test that as users cannot login with wrong as token + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.another_service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_no_token(self): + """Test that users must provide a token when using the appservice + login method + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + request, channel = self.make_request(b"POST", LOGIN_URL, params) + + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) -- cgit 1.5.1 From 8a4a4186ded34bab1ffb4ee1cebcb476890da207 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Sep 2020 09:56:44 -0400 Subject: Simplify super() calls to Python 3 syntax. (#8344) This converts calls like super(Foo, self) -> super(). 
Generated with: sed -i "" -Ee 's/super\([^\(]+\)/super()/g' **/*.py --- changelog.d/8344.misc | 1 + scripts-dev/definitions.py | 2 +- scripts-dev/federation_client.py | 2 +- synapse/api/errors.py | 50 ++++++++++------------ synapse/api/filtering.py | 2 +- synapse/app/generic_worker.py | 6 +-- synapse/appservice/api.py | 2 +- synapse/config/consent_config.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/server_notices_config.py | 2 +- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_client.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 10 ++--- synapse/groups/groups_server.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/auth.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 4 +- synapse/handlers/directory.py | 2 +- synapse/handlers/events.py | 4 +- synapse/handlers/federation.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/profile.py | 4 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member_worker.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/set_password.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/http/__init__.py | 2 +- synapse/logging/formatter.py | 2 +- synapse/logging/scopecontextmanager.py | 6 +-- synapse/push/__init__.py | 2 +- synapse/replication/http/devices.py | 2 +- synapse/replication/http/federation.py | 8 ++-- synapse/replication/http/login.py | 2 +- synapse/replication/http/membership.py | 6 +-- synapse/replication/http/register.py | 4 +- synapse/replication/http/send_event.py | 2 +- synapse/replication/slave/storage/_base.py | 2 +- synapse/replication/slave/storage/account_data.py | 2 +- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/replication/slave/storage/deviceinbox.py | 2 +- synapse/replication/slave/storage/devices.py | 2 +- synapse/replication/slave/storage/events.py | 2 +- synapse/replication/slave/storage/filtering.py | 2 +- synapse/replication/slave/storage/groups.py | 2 +- synapse/replication/slave/storage/presence.py | 2 +- synapse/replication/slave/storage/pushers.py | 2 +- synapse/replication/slave/storage/receipts.py | 2 +- synapse/replication/slave/storage/room.py | 2 +- synapse/replication/tcp/streams/_base.py | 2 +- synapse/rest/admin/devices.py | 2 +- synapse/rest/client/v1/directory.py | 6 +-- synapse/rest/client/v1/events.py | 4 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 4 +- synapse/rest/client/v1/logout.py | 4 +- synapse/rest/client/v1/presence.py | 2 +- synapse/rest/client/v1/profile.py | 6 +-- synapse/rest/client/v1/push_rule.py | 2 +- synapse/rest/client/v1/pusher.py | 6 +-- synapse/rest/client/v1/room.py | 38 ++++++++-------- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 22 +++++----- synapse/rest/client/v2_alpha/account_data.py | 4 +- synapse/rest/client/v2_alpha/account_validity.py | 4 +- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/client/v2_alpha/capabilities.py | 2 +- synapse/rest/client/v2_alpha/devices.py | 6 +-- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/rest/client/v2_alpha/groups.py | 48 ++++++++++----------- synapse/rest/client/v2_alpha/keys.py | 12 +++--- synapse/rest/client/v2_alpha/notifications.py | 2 +- 
synapse/rest/client/v2_alpha/openid.py | 2 +- synapse/rest/client/v2_alpha/password_policy.py | 2 +- synapse/rest/client/v2_alpha/read_marker.py | 2 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/register.py | 10 ++--- synapse/rest/client/v2_alpha/relations.py | 8 ++-- synapse/rest/client/v2_alpha/report_event.py | 2 +- synapse/rest/client/v2_alpha/room_keys.py | 6 +-- .../client/v2_alpha/room_upgrade_rest_servlet.py | 2 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/client/v2_alpha/shared_rooms.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/client/v2_alpha/tags.py | 4 +- synapse/rest/client/v2_alpha/thirdparty.py | 8 ++-- synapse/rest/client/v2_alpha/tokenrefresh.py | 2 +- synapse/rest/client/v2_alpha/user_directory.py | 2 +- synapse/rest/client/versions.py | 2 +- synapse/storage/databases/main/__init__.py | 2 +- synapse/storage/databases/main/account_data.py | 4 +- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/deviceinbox.py | 4 +- synapse/storage/databases/main/devices.py | 4 +- synapse/storage/databases/main/event_federation.py | 2 +- .../storage/databases/main/event_push_actions.py | 4 +- .../storage/databases/main/events_bg_updates.py | 2 +- synapse/storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/media_repository.py | 6 +-- .../storage/databases/main/monthly_active_users.py | 4 +- synapse/storage/databases/main/push_rule.py | 2 +- synapse/storage/databases/main/receipts.py | 4 +- synapse/storage/databases/main/registration.py | 6 +-- synapse/storage/databases/main/room.py | 6 +-- synapse/storage/databases/main/roommember.py | 6 +-- synapse/storage/databases/main/search.py | 4 +- synapse/storage/databases/main/state.py | 6 +-- synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/stream.py | 2 +- synapse/storage/databases/main/transactions.py | 2 +- synapse/storage/databases/main/user_directory.py | 4 +- synapse/storage/databases/state/bg_updates.py | 2 +- synapse/storage/databases/state/store.py | 2 +- synapse/util/manhole.py | 2 +- synapse/util/retryutils.py | 2 +- tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_e2e_room_keys.py | 2 +- tests/replication/slave/storage/test_events.py | 2 +- tests/rest/test_well_known.py | 2 +- tests/server.py | 2 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_devices.py | 2 +- tests/test_state.py | 2 +- tests/unittest.py | 2 +- 133 files changed, 272 insertions(+), 281 deletions(-) create mode 100644 changelog.d/8344.misc (limited to 'tests') diff --git a/changelog.d/8344.misc b/changelog.d/8344.misc new file mode 100644 index 0000000000..0b342d5137 --- /dev/null +++ b/changelog.d/8344.misc @@ -0,0 +1 @@ +Simplify `super()` calls to Python 3 syntax. 
diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 15e6ce6e16..313860df13 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -11,7 +11,7 @@ import yaml class DefinitionVisitor(ast.NodeVisitor): def __init__(self): - super(DefinitionVisitor, self).__init__() + super().__init__() self.functions = {} self.classes = {} self.names = {} diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 848a826f17..abcec48c4f 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -321,7 +321,7 @@ class MatrixConnectionAdapter(HTTPAdapter): url = urlparse.urlunparse( ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) ) - return super(MatrixConnectionAdapter, self).get_connection(url, proxies) + return super().get_connection(url, proxies) if __name__ == "__main__": diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 94a9e58eae..cd6670d0a2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -87,7 +87,7 @@ class CodeMessageException(RuntimeError): """ def __init__(self, code: Union[int, HTTPStatus], msg: str): - super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) + super().__init__("%d: %s" % (code, msg)) # Some calls to this method pass instances of http.HTTPStatus for `code`. # While HTTPStatus is a subclass of int, it has magic __str__ methods @@ -138,7 +138,7 @@ class SynapseError(CodeMessageException): msg: The human-readable error message. errcode: The matrix error code e.g 'M_FORBIDDEN' """ - super(SynapseError, self).__init__(code, msg) + super().__init__(code, msg) self.errcode = errcode def error_dict(self): @@ -159,7 +159,7 @@ class ProxiedRequestError(SynapseError): errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, ): - super(ProxiedRequestError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} # type: Dict else: @@ -181,7 +181,7 @@ class ConsentNotGivenError(SynapseError): msg: The human-readable error message consent_url: The URL where the user can give their consent """ - super(ConsentNotGivenError, self).__init__( + super().__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri @@ -201,7 +201,7 @@ class UserDeactivatedError(SynapseError): Args: msg: The human-readable error message """ - super(UserDeactivatedError, self).__init__( + super().__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED ) @@ -225,7 +225,7 @@ class FederationDeniedError(SynapseError): self.destination = destination - super(FederationDeniedError, self).__init__( + super().__init__( code=403, msg="Federation denied with %s." 
% (self.destination,), errcode=Codes.FORBIDDEN, @@ -244,9 +244,7 @@ class InteractiveAuthIncompleteError(Exception): """ def __init__(self, session_id: str, result: "JsonDict"): - super(InteractiveAuthIncompleteError, self).__init__( - "Interactive auth not yet complete" - ) + super().__init__("Interactive auth not yet complete") self.session_id = session_id self.result = result @@ -261,14 +259,14 @@ class UnrecognizedRequestError(SynapseError): message = "Unrecognized request" else: message = args[0] - super(UnrecognizedRequestError, self).__init__(400, message, **kwargs) + super().__init__(400, message, **kwargs) class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND): - super(NotFoundError, self).__init__(404, msg, errcode=errcode) + super().__init__(404, msg, errcode=errcode) class AuthError(SynapseError): @@ -279,7 +277,7 @@ class AuthError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class InvalidClientCredentialsError(SynapseError): @@ -335,7 +333,7 @@ class ResourceLimitError(SynapseError): ): self.admin_contact = admin_contact self.limit_type = limit_type - super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) + super().__init__(code, msg, errcode=errcode) def error_dict(self): return cs_error( @@ -352,7 +350,7 @@ class EventSizeError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.TOO_LARGE - super(EventSizeError, self).__init__(413, *args, **kwargs) + super().__init__(413, *args, **kwargs) class EventStreamError(SynapseError): @@ -361,7 +359,7 @@ class EventStreamError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.BAD_PAGINATION - super(EventStreamError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class LoginError(SynapseError): @@ -384,7 +382,7 @@ class InvalidCaptchaError(SynapseError): error_url: Optional[str] = None, errcode: str = Codes.CAPTCHA_INVALID, ): - super(InvalidCaptchaError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) self.error_url = error_url def error_dict(self): @@ -402,7 +400,7 @@ class LimitExceededError(SynapseError): retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super(LimitExceededError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self): @@ -418,9 +416,7 @@ class RoomKeysVersionError(SynapseError): Args: current_version: the current version of the store they should have used """ - super(RoomKeysVersionError, self).__init__( - 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION - ) + super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION) self.current_version = current_version @@ -429,7 +425,7 @@ class UnsupportedRoomVersionError(SynapseError): not support.""" def __init__(self, msg: str = "Homeserver does not support this room version"): - super(UnsupportedRoomVersionError, self).__init__( + super().__init__( code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -440,7 +436,7 @@ class ThreepidValidationError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - 
super(ThreepidValidationError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class IncompatibleRoomVersionError(SynapseError): @@ -451,7 +447,7 @@ class IncompatibleRoomVersionError(SynapseError): """ def __init__(self, room_version: str): - super(IncompatibleRoomVersionError, self).__init__( + super().__init__( code=400, msg="Your homeserver does not support the features required to " "join this room", @@ -473,7 +469,7 @@ class PasswordRefusedError(SynapseError): msg: str = "This password doesn't comply with the server's policy", errcode: str = Codes.WEAK_PASSWORD, ): - super(PasswordRefusedError, self).__init__( + super().__init__( code=400, msg=msg, errcode=errcode, ) @@ -488,7 +484,7 @@ class RequestSendFailed(RuntimeError): """ def __init__(self, inner_exception, can_retry): - super(RequestSendFailed, self).__init__( + super().__init__( "Failed to send request: %s: %s" % (type(inner_exception).__name__, inner_exception) ) @@ -542,7 +538,7 @@ class FederationError(RuntimeError): self.source = source msg = "%s %s: %s" % (level, code, reason) - super(FederationError, self).__init__(msg) + super().__init__(msg) def get_dict(self): return { @@ -570,7 +566,7 @@ class HttpResponseException(CodeMessageException): msg: reason phrase from HTTP response status line response: body of response """ - super(HttpResponseException, self).__init__(code, msg) + super().__init__(code, msg) self.response = response def to_synapse_error(self): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bb33345be6..5caf336fd0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -132,7 +132,7 @@ def matrix_user_id_validator(user_id_str): class Filtering: def __init__(self, hs): - super(Filtering, self).__init__() + super().__init__() self.store = hs.get_datastore() async def get_user_filter(self, user_localpart, filter_id): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f985810e88..c38413c893 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -152,7 +152,7 @@ class PresenceStatusStubServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P[^/]*)/status") def __init__(self, hs): - super(PresenceStatusStubServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -176,7 +176,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.http_client = hs.get_simple_http_client() @@ -646,7 +646,7 @@ class GenericWorkerServer(HomeServer): class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): - super(GenericWorkerReplicationHandler, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bb6fa8299a..1514c0f691 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -88,7 +88,7 @@ class ApplicationServiceApi(SimpleHttpClient): """ def __init__(self, hs): - super(ApplicationServiceApi, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.protocol_meta_cache = ResponseCache( diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index aec9c4bbce..fbddebeeab 100644 --- a/synapse/config/consent_config.py +++ 
b/synapse/config/consent_config.py @@ -77,7 +77,7 @@ class ConsentConfig(Config): section = "consent" def __init__(self, *args): - super(ConsentConfig, self).__init__(*args) + super().__init__(*args) self.user_consent_version = None self.user_consent_template_dir = None diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a185655774..5ffbb934fe 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -30,7 +30,7 @@ class AccountValidityConfig(Config): def __init__(self, config, synapse_config): if config is None: return - super(AccountValidityConfig, self).__init__() + super().__init__() self.enabled = config.get("enabled", False) self.renew_by_email_enabled = "renew_at" in config diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 6c427b6f92..57f69dc8e2 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -62,7 +62,7 @@ class ServerNoticesConfig(Config): section = "servernotices" def __init__(self, *args): - super(ServerNoticesConfig, self).__init__(*args) + super().__init__(*args) self.server_notices_mxid = None self.server_notices_mxid_display_name = None self.server_notices_mxid_avatar_url = None diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 32c31b1cd1..42e4087a92 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -558,7 +558,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" def __init__(self, hs): - super(PerspectivesKeyFetcher, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_http_client() self.key_servers = self.config.key_servers @@ -728,7 +728,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" def __init__(self, hs): - super(ServerKeyFetcher, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_http_client() diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a2e8d96ea2..639d19f696 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -79,7 +79,7 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): def __init__(self, hs): - super(FederationClient, self).__init__(hs) + super().__init__(hs) self.pdu_destination_tried = {} self._clock.looping_call(self._clear_tried_cache, 60 * 1000) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ff00f0b302..2dcd081cbc 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -90,7 +90,7 @@ pdu_process_time = Histogram( class FederationServer(FederationBase): def __init__(self, hs): - super(FederationServer, self).__init__(hs) + super().__init__(hs) self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cc7e9a973b..3a6b95631e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -68,7 +68,7 @@ class TransportLayerServer(JsonResource): self.clock = hs.get_clock() self.servlet_groups = servlet_groups - super(TransportLayerServer, self).__init__(hs, canonical_json=False) + super().__init__(hs, canonical_json=False) self.authenticator = Authenticator(hs) 
self.ratelimiter = hs.get_federation_ratelimiter() @@ -376,9 +376,7 @@ class FederationSendServlet(BaseFederationServlet): RATELIMIT = False def __init__(self, handler, server_name, **kwargs): - super(FederationSendServlet, self).__init__( - handler, server_name=server_name, **kwargs - ) + super().__init__(handler, server_name=server_name, **kwargs) self.server_name = server_name # This is when someone is trying to send us a bunch of data. @@ -773,9 +771,7 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name - ) + super().__init__(handler, authenticator, ratelimiter, server_name) self.allow_access = allow_access async def on_GET(self, origin, content, query): diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 1dd20ee4e1..e5f85b472d 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -336,7 +336,7 @@ class GroupsServerWorkerHandler: class GroupsServerHandler(GroupsServerWorkerHandler): def __init__(self, hs): - super(GroupsServerHandler, self).__init__(hs) + super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5e5a64037d..dd981c597e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): def __init__(self, hs): - super(AdminHandler, self).__init__(hs) + super().__init__(hs) self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4e658d9a48..0322b60cfc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -145,7 +145,7 @@ class AuthHandler(BaseHandler): Args: hs (synapse.server.HomeServer): """ - super(AuthHandler, self).__init__(hs) + super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 25169157c1..0635ad5708 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" def __init__(self, hs): - super(DeactivateAccountHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 4b0a4f96cc..55a9787439 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -48,7 +48,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100 class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): - super(DeviceWorkerHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.state = hs.get_state_handler() @@ -251,7 +251,7 @@ class DeviceWorkerHandler(BaseHandler): class DeviceHandler(DeviceWorkerHandler): def __init__(self, hs): - super(DeviceHandler, self).__init__(hs) + super().__init__(hs) self.federation_sender = hs.get_federation_sender() diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 46826eb784..62aa9a2da8 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ 
-37,7 +37,7 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): def __init__(self, hs): - super(DirectoryHandler, self).__init__(hs) + super().__init__(hs) self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index fdce54c5c3..0875b74ea8 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(EventStreamHandler, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() @@ -142,7 +142,7 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(EventHandler, self).__init__(hs) + super().__init__(hs) self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 262901363f..96eeff7b1b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -115,7 +115,7 @@ class FederationHandler(BaseHandler): """ def __init__(self, hs): - super(FederationHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 44df567983..9684e60fc8 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): def __init__(self, hs): - super(GroupsLocalHandler, self).__init__(hs) + super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0ce6ddfbe4..ab15570f7a 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -45,7 +45,7 @@ id_server_scheme = "https://" class IdentityHandler(BaseHandler): def __init__(self, hs): - super(IdentityHandler, self).__init__(hs) + super().__init__(hs) self.http_client = SimpleHttpClient(hs) # We create a blacklisting instance of SimpleHttpClient for contacting identity diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index ba4828c713..8cd7eb22a3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(InitialSyncHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 0cb8fad89a..5453e6dfc8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler): """ def __init__(self, hs): - super(BaseProfileHandler, self).__init__(hs) + super().__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler): PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): - super(MasterProfileHandler, self).__init__(hs) + super().__init__(hs) assert hs.config.worker_app is None diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index e3b528d271..c32f314a1c 100644 --- a/synapse/handlers/read_marker.py +++ 
b/synapse/handlers/read_marker.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler(BaseHandler): def __init__(self, hs): - super(ReadMarkerHandler, self).__init__(hs) + super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index bdd8e52edd..7225923757 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class ReceiptsHandler(BaseHandler): def __init__(self, hs): - super(ReceiptsHandler, self).__init__(hs) + super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cde2dbca92..538f4b2a61 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler): Args: hs (synapse.server.HomeServer): """ - super(RegistrationHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index eeade6ad3f..11bf146bed 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000 class RoomCreationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(RoomCreationHandler, self).__init__(hs) + super().__init__(hs) self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5dd7b28391..4a13c8e912 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs): - super(RoomListHandler, self).__init__(hs) + super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") self.remote_response_cache = ResponseCache( diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index e7f34737c6..f2e88f6a5b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class RoomMemberWorkerHandler(RoomMemberHandler): def __init__(self, hs): - super(RoomMemberWorkerHandler, self).__init__(hs) + super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index d58f9788c5..6a76c20d79 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): def __init__(self, hs): - super(SearchHandler, self).__init__(hs) + super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 4d245b618b..a5d67f828f 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler): """Handler which deals with changing user 
account passwords""" def __init__(self, hs): - super(SetPasswordHandler, self).__init__(hs) + super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index e21f8dbc58..79393c8829 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler): """ def __init__(self, hs): - super(UserDirectoryHandler, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.state = hs.get_state_handler() diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 3880ce0d94..8eb3638591 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -27,7 +27,7 @@ class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" def __init__(self): - super(RequestTimedOutError, self).__init__(504, "Timed out") + super().__init__(504, "Timed out") def cancelled_to_request_timed_out_error(value, timeout): diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index d736ad5b9b..11f60a77f7 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -30,7 +30,7 @@ class LogFormatter(logging.Formatter): """ def __init__(self, *args, **kwargs): - super(LogFormatter, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def formatException(self, ei): sio = StringIO() diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 026854b4c7..7b9c657456 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -107,7 +107,7 @@ class _LogContextScope(Scope): finish_on_close (Boolean): if True finish the span when the scope is closed """ - super(_LogContextScope, self).__init__(manager, span) + super().__init__(manager, span) self.logcontext = logcontext self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext @@ -120,9 +120,9 @@ class _LogContextScope(Scope): def __exit__(self, type, value, traceback): if type == twisted.internet.defer._DefGen_Return: - super(_LogContextScope, self).__exit__(None, None, None) + super().__exit__(None, None, None) else: - super(_LogContextScope, self).__exit__(type, value, traceback) + super().__exit__(type, value, traceback) if self._enter_logcontext: self.logcontext.__exit__(type, value, traceback) else: # the logcontext existed before the creation of the scope diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index edf45dc599..5a437f9810 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -16,4 +16,4 @@ class PusherConfigException(Exception): def __init__(self, msg): - super(PusherConfigException, self).__init__(msg) + super().__init__(msg) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 20f3ba76c0..807b85d2e1 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): CACHE = False def __init__(self, hs): - super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs) + super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater self.store = hs.get_datastore() diff --git a/synapse/replication/http/federation.py 
b/synapse/replication/http/federation.py index 5c8be747e1..5393b9a9e7 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): PATH_ARGS = () def __init__(self, hs): - super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.storage = hs.get_storage() @@ -150,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): PATH_ARGS = ("edu_type",) def __init__(self, hs): - super(ReplicationFederationSendEduRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -193,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): CACHE = False def __init__(self, hs): - super(ReplicationGetQueryRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -236,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id",) def __init__(self, hs): - super(ReplicationCleanRoomRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index fb326bb869..4c81e2d784 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(RegisterDeviceReplicationServlet, self).__init__(hs) + super().__init__(hs) self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 08095fdf7d..30680baee8 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): - super(ReplicationRemoteJoinRestServlet, self).__init__(hs) + super().__init__(hs) self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() @@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): PATH_ARGS = ("invite_event_id",) def __init__(self, hs: "HomeServer"): - super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): CACHE = False # No point caching as should return instantly. 
def __init__(self, hs): - super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs) + super().__init__(hs) self.registeration_handler = hs.get_registration_handler() self.store = hs.get_datastore() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index a02b27474d..7b12ec9060 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(ReplicationRegisterServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(ReplicationPostRegisterActionsServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index f13d452426..9a3a694d5d 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): PATH_ARGS = ("event_id",) def __init__(self, hs): - super(ReplicationSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 60f2e1245f..d25fa49e1a 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(BaseSlavedStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( db_conn, diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bb66ba9b80..4268565fc8 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved ], ) - super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index a6fdedde63..1f8dafe7ea 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -22,7 +22,7 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedClientIpStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 533d927701..5b045bed02 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ 
b/synapse/replication/slave/storage/deviceinbox.py @@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_inbox", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 3b788c9625..e0d86240dd 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedDeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index da1cc836cf..fbffe6d85c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -56,7 +56,7 @@ class SlavedEventStore( BaseSlavedStore, ): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedEventStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 2562b6fc38..6a23252861 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -21,7 +21,7 @@ from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedFilteringStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 567b4a5cc1..30955bcbfe 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 025f6f6be8..55620c03d8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedPresenceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) # type: 
ignore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 9da218bfe8..c418730ba8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedPusherStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 5c2986e050..6195917376 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 80ae803ad9..109ac6bea1 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 1f609f158c..54dccd15a6 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -345,7 +345,7 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__( + super().__init__( hs.get_instance_name(), self._current_token, self.store.get_all_push_rule_updates, diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 4670d7160d..a163863322 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -36,7 +36,7 @@ class DeviceRestServlet(RestServlet): ) def __init__(self, hs): - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index b210015173..faabeeb91c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet): PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet): PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = 
hs.get_auth() @@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) def __init__(self, hs): - super(ClientAppserviceDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 25effd0261..985d994f6b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -30,7 +30,7 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 def __init__(self, hs): - super(EventStreamRestServlet, self).__init__() + super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() @@ -74,7 +74,7 @@ class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(EventRestServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 910b3b4eeb..d7042786ce 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -24,7 +24,7 @@ class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) def __init__(self, hs): - super(InitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index dd8cdc0d9f..250b03a025 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -48,7 +48,7 @@ class LoginRestServlet(RestServlet): APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" def __init__(self, hs): - super(LoginRestServlet, self).__init__() + super().__init__() self.hs = hs # JWT configuration variables. 
@@ -429,7 +429,7 @@ class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) def __init__(self, hs): - super(CasTicketServlet, self).__init__() + super().__init__() self._cas_handler = hs.get_cas_handler() async def on_GET(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b0c30b65be..f792b50cdc 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) def __init__(self, hs): - super(LogoutRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) def __init__(self, hs): - super(LogoutAllRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 970fdd5834..79d8e3057f 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) def __init__(self, hs): - super(PresenceStatusRestServlet, self).__init__() + super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e7fe50ed72..b686cd671f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -25,7 +25,7 @@ class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) def __init__(self, hs): - super(ProfileDisplaynameRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() @@ -73,7 +73,7 @@ class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) def __init__(self, hs): - super(ProfileAvatarURLRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() @@ -124,7 +124,7 @@ class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) def __init__(self, hs): - super(ProfileRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index ddf8ed5e9c..f9eecb7cf5 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -38,7 +38,7 @@ class PushRuleRestServlet(RestServlet): ) def __init__(self, hs): - super(PushRuleRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5f65cb7d83..28dabf1c7a 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) def __init__(self, 
hs): - super(PushersRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) def __init__(self, hs): - super(PushersSetRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet): SUCCESS_HTML = b"You have been unsubscribed" def __init__(self, hs): - super(PushersRemoveRestServlet, self).__init__() + super().__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 84baf3d59b..7e64a2e0fe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -57,7 +57,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): def __init__(self, hs): - super(TransactionRestServlet, self).__init__() + super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here def __init__(self, hs): - super(RoomCreateRestServlet, self).__init__(hs) + super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() @@ -111,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomStateEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() @@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -280,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): - super(JoinRoomAliasServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -343,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) def __init__(self, hs): - super(PublicRoomListRestServlet, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -448,7 +448,7 @@ class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) def __init__(self, hs): - super(RoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -499,7 +499,7 @@ class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) def __init__(self, hs): - super(JoinedRoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -518,7 +518,7 @@ class RoomMessageListRestServlet(RestServlet): PATTERNS = 
client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) def __init__(self, hs): - super(RoomMessageListRestServlet, self).__init__() + super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() @@ -557,7 +557,7 @@ class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) def __init__(self, hs): - super(RoomStateRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -577,7 +577,7 @@ class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) def __init__(self, hs): - super(RoomInitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() @@ -596,7 +596,7 @@ class RoomEventServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() @@ -628,7 +628,7 @@ class RoomEventContextServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventContextServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -675,7 +675,7 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomForgetRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -701,7 +701,7 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomMembershipRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -792,7 +792,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomRedactEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -841,7 +841,7 @@ class RoomTypingRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomTypingRestServlet, self).__init__() + super().__init__() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() @@ -914,7 +914,7 @@ class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) def __init__(self, hs): - super(SearchRestServlet, self).__init__() + super().__init__() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -935,7 +935,7 @@ class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) def __init__(self, hs): - super(JoinedRoomsRestServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 50277c6cf6..b8d491ca5c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet): PATTERNS = client_patterns("/voip/turnServer$", v1=True) def __init__(self, hs): - 
super(VoipRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ade97a6708..c3ce0f6259 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -52,7 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") def __init__(self, hs): - super(EmailPasswordRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.datastore = hs.get_datastore() self.config = hs.config @@ -156,7 +156,7 @@ class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") def __init__(self, hs): - super(PasswordRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -282,7 +282,7 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") def __init__(self, hs): - super(DeactivateAccountRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -330,7 +330,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): - super(EmailThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.identity_handler = hs.get_handlers().identity_handler @@ -427,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs): self.hs = hs - super(MsisdnThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler @@ -606,7 +606,7 @@ class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") def __init__(self, hs): - super(ThreepidRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -662,7 +662,7 @@ class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") def __init__(self, hs): - super(ThreepidAddRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -713,7 +713,7 @@ class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") def __init__(self, hs): - super(ThreepidBindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -742,7 +742,7 @@ class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") def __init__(self, hs): - super(ThreepidUnbindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -773,7 +773,7 @@ class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") def __init__(self, hs): - super(ThreepidDeleteRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -852,7 +852,7 @@ class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") def __init__(self, hs): - super(WhoamiRestServlet, 
self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index c1d4cd0caf..87a5b1b86b 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -34,7 +34,7 @@ class AccountDataServlet(RestServlet): ) def __init__(self, hs): - super(AccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -86,7 +86,7 @@ class RoomAccountDataServlet(RestServlet): ) def __init__(self, hs): - super(RoomAccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index d06336ceea..bd7f9ae203 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValidityRenewServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() @@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValiditySendMailServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8e585e9153..097538f968 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -124,7 +124,7 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): - super(AuthRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index fe9d019c44..76879ac559 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(CapabilitiesRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index c0714fcfb1..7e174de692 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): - super(DeleteDevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DeviceRestServlet, self).__init__() + 
super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index b28da017cd..7cc692643b 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") def __init__(self, hs): - super(GetFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() @@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter") def __init__(self, hs): - super(CreateFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 13ecf7005d..a3bb095c2d 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -32,7 +32,7 @@ class GroupServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") def __init__(self, hs): - super(GroupServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -66,7 +66,7 @@ class GroupSummaryServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") def __init__(self, hs): - super(GroupSummaryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -97,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryRoomsCatServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -137,7 +137,7 @@ class GroupCategoryServlet(RestServlet): ) def __init__(self, hs): - super(GroupCategoryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -181,7 +181,7 @@ class GroupCategoriesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") def __init__(self, hs): - super(GroupCategoriesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -204,7 +204,7 @@ class GroupRoleServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") def __init__(self, hs): - super(GroupRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -248,7 +248,7 @@ class GroupRolesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") def __init__(self, hs): - super(GroupRolesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -279,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryUsersRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -317,7 +317,7 @@ class GroupRoomServlet(RestServlet): PATTERNS = 
client_patterns("/groups/(?P[^/]*)/rooms$") def __init__(self, hs): - super(GroupRoomServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -343,7 +343,7 @@ class GroupUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") def __init__(self, hs): - super(GroupUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -366,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") def __init__(self, hs): - super(GroupInvitedUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -389,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") def __init__(self, hs): - super(GroupSettingJoinPolicyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @@ -413,7 +413,7 @@ class GroupCreateServlet(RestServlet): PATTERNS = client_patterns("/create_group$") def __init__(self, hs): - super(GroupCreateServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -444,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -481,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsConfigServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -507,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -536,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersKickServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -560,7 +560,7 @@ class GroupSelfLeaveServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") def __init__(self, hs): - super(GroupSelfLeaveServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -584,7 +584,7 @@ class GroupSelfJoinServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") def __init__(self, hs): - super(GroupSelfJoinServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -608,7 +608,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") def __init__(self, hs): - super(GroupSelfAcceptInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = 
hs.get_groups_local_handler() @@ -632,7 +632,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") def __init__(self, hs): - super(GroupSelfUpdatePublicityServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -655,7 +655,7 @@ class PublicisedGroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") def __init__(self, hs): - super(PublicisedGroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -676,7 +676,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): - super(PublicisedGroupsForUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -700,7 +700,7 @@ class GroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): - super(GroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24bb090822..7abd6ff333 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyQueryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -177,7 +177,7 @@ class KeyChangesServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyChangesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -222,7 +222,7 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): - super(OneTimeKeyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -250,7 +250,7 @@ class SigningKeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SigningKeyUploadServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -308,7 +308,7 @@ class SignaturesUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SignaturesUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index aa911d75ee..87063ec8b1 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") def __init__(self, hs): - super(NotificationsServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() 
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index 6ae9a5a8e9..5b996e2d63 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 def __init__(self, hs): - super(IdTokenServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py index 968403cca4..68b27ff23a 100644 --- a/synapse/rest/client/v2_alpha/password_policy.py +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(PasswordPolicyServlet, self).__init__() + super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index 67cbc37312..55c6688f52 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") def __init__(self, hs): - super(ReadMarkerRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 92555bd4a9..6f7246a394 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet): ) def __init__(self, hs): - super(ReceiptRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 0705718d00..ffa2dfce42 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -76,7 +76,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(EmailRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.config = hs.config @@ -174,7 +174,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(MsisdnRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -249,7 +249,7 @@ class RegistrationSubmitTokenServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegistrationSubmitTokenServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.config = hs.config @@ -319,7 +319,7 @@ class UsernameAvailabilityRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UsernameAvailabilityRestServlet, self).__init__() + super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( @@ -363,7 +363,7 @@ class RegisterRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - 
super(RegisterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index e29f49f7f5..18c75738f8 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -61,7 +61,7 @@ class RelationSendServlet(RestServlet): ) def __init__(self, hs): - super(RelationSendServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) @@ -138,7 +138,7 @@ class RelationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -233,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() @@ -311,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationGroupPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e15927c4ea..215d619ca1 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") def __init__(self, hs): - super(ReportEventRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 59529707df..53de97923f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysNewVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index 39a5518614..bf030e0ff4 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -53,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomUpgradeRestServlet, self).__init__() + super().__init__() self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() diff --git 
a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index db829f3098..bc4f43639a 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SendToDeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py index 2492634dac..c866d5151c 100644 --- a/synapse/rest/client/v2_alpha/shared_rooms.py +++ b/synapse/rest/client/v2_alpha/shared_rooms.py @@ -34,7 +34,7 @@ class UserSharedRoomsServlet(RestServlet): ) def __init__(self, hs): - super(UserSharedRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.user_directory_active = hs.config.update_user_directory diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a0b00135e1..51e395cc64 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -74,7 +74,7 @@ class SyncRestServlet(RestServlet): ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): - super(SyncRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.sync_handler = hs.get_sync_handler() diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index a3f12e8a77..bf3a79db44 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -31,7 +31,7 @@ class TagListServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") def __init__(self, hs): - super(TagListServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -56,7 +56,7 @@ class TagServlet(RestServlet): ) def __init__(self, hs): - super(TagServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 23709960ad..0c127a1b5f 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocols") def __init__(self, hs): - super(ThirdPartyProtocolsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") def __init__(self, hs): - super(ThirdPartyProtocolServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") def __init__(self, hs): - super(ThirdPartyUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") def __init__(self, hs): - super(ThirdPartyLocationServlet, 
self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 83f3b6b70a..79317c74ba 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet): PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): - super(TokenRefreshRestServlet, self).__init__() + super().__init__() async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index bef91a2d3e..ad598cefe0 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -31,7 +31,7 @@ class UserDirectorySearchRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UserDirectorySearchRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 24ac57f35d..d5018afbda 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -28,7 +28,7 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def __init__(self, hs): - super(VersionsRestServlet, self).__init__() + super().__init__() self.config = hs.config def on_GET(self, request): diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 2ae2fbd5d7..ccb3384db9 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -172,7 +172,7 @@ class DataStore( else: self._cache_id_gen = None - super(DataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 5f1a2b9aa6..c5a36990e4 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -42,7 +42,7 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -313,7 +313,7 @@ class AccountDataStore(AccountDataWorkerStore): ], ) - super(AccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 454c0bc50c..85f6b1e3fd 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git 
a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c2fc847fbc..239c7a949c 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "user_ips_device_index", @@ -358,7 +358,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): name="client_ip_last_seen", keylen=4, max_entries=50000 ) - super(ClientIpStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 0044433110..e71217a41f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_inbox_stream_index", @@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 306fc6947c..c04374e43d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -701,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", @@ -826,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4c3c162acf..6d3689c09e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventFederationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7805fb814e..62f1738732 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -68,7 +68,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -661,7 +661,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index e53c6373a8..5e4af2eb51 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cd3739c16c..de9e8d1dc6 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -75,7 +75,7 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 1d76c761a6..cc538c5c10 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -24,9 +24,7 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - 
super(MediaRepositoryBackgroundUpdateStore, self).__init__( - database, db_conn, hs - ) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -94,7 +92,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 1d793d3deb..e0cedd1aac 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -28,7 +28,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -120,7 +120,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_stats_only = hs.config.mau_stats_only diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index b7a8d34ce1..e20a16f907 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -77,7 +77,7 @@ class PushRulesWorkerStore( """ def __init__(self, database: DatabasePool, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: self._push_rules_stream_id_gen = StreamIdGenerator( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6568bddd81..f880b5e562 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -39,7 +39,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -386,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore): db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 01f20c03c2..675e81fe34 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationWorkerStore, 
self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @@ -764,7 +764,7 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -892,7 +892,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 127588ce4c..bd6f9553c6 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -69,7 +69,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -863,7 +863,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -1074,7 +1074,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 91a8b43da3..4fa8767b01 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -55,7 +55,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? 
@@ -819,7 +819,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -973,7 +973,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f01cf2fd02..e34fce6281 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. 
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 5c6168e301..3c1e33819b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room @@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" def __init__(self, database: DatabasePool, db_conn, hs): - super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 30840dbbaa..d7816a8606 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -61,7 +61,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(StatsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 7dbe11513b..5dac78e574 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -266,7 +266,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): """ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - super(StreamWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() self._send_federation = hs.should_send_federation() diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 091367006e..99cffff50c 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -48,7 +48,7 @@ class TransactionStore(SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(TransactionStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f2f9a5799a..5a390ff2f6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -564,7 +564,7 @@ class 
UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 139085b672..acb24e33af 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" def __init__(self, database: DatabasePool, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e924f1ca3b..bec3780a32 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 631654f297..da24ba0470 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -94,7 +94,7 @@ class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" def connectionMade(self): - super(SynapseManhole, self).connectionMade() + super().connectionMade() # replace the manhole interpreter with our own impl self.interpreter = SynapseManholeInterpreter(self, self.namespace) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 79869aaa44..a5cc9d0551 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -45,7 +45,7 @@ class NotRetryingDestination(Exception): """ msg = "Not retrying server %s." 
% (destination,) - super(NotRetryingDestination, self).__init__(msg) + super().__init__(msg) self.retry_last_ts = retry_last_ts self.retry_interval = retry_interval diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 210ddcbb88..366dcfb670 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -30,7 +30,7 @@ from tests import unittest, utils class E2eKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): - super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 3362050ce0..7adde9b9de 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -47,7 +47,7 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): - super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 561258a356..bc578411d6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -58,7 +58,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] - return super(SlavedEventStoreTestCase, self).setUp() + return super().setUp() def prepare(self, *args, **kwargs): super().prepare(*args, **kwargs) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index b090bb974c..dcd65c2a50 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -21,7 +21,7 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): def setUp(self): - super(WellKnownTests, self).setUp() + super().setUp() # replace the JsonResource with a WellKnownResource self.resource = WellKnownResource(self.hs) diff --git a/tests/server.py b/tests/server.py index 61ec670155..b404ad4e2a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -260,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) - super(ThreadedMemoryReactorClock, self).__init__() + super().__init__() def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cb808d4de4..46f94914ff 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -413,7 +413,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(TestTransactionStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) class 
ApplicationServiceStoreConfigTestCase(unittest.TestCase): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 34ae8c9da7..ecb00f4e02 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -23,7 +23,7 @@ import tests.utils class DeviceStoreTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): - super(DeviceStoreTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore @defer.inlineCallbacks diff --git a/tests/test_state.py b/tests/test_state.py index 2d58467932..80b0ccbc40 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -125,7 +125,7 @@ class StateGroupStore: class DictObj(dict): def __init__(self, **kwargs): - super(DictObj, self).__init__(kwargs) + super().__init__(kwargs) self.__dict__ = self diff --git a/tests/unittest.py b/tests/unittest.py index 128dd4e19c..dabf69cff4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -92,7 +92,7 @@ class TestCase(unittest.TestCase): root logger's logging level while that test (case|method) runs.""" def __init__(self, methodName, *args, **kwargs): - super(TestCase, self).__init__(methodName, *args, **kwargs) + super().__init__(methodName, *args, **kwargs) method = getattr(self, methodName) -- cgit 1.5.1 From 36efbcaf511790d6f1dd7df2260900f07489bda6 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Fri, 18 Sep 2020 14:59:13 +0100 Subject: Catch-up after Federation Outage (bonus): Catch-up on Synapse Startup (#8322) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Olivier Wilkinson (reivilibre) Co-authored-by: Patrick Cloke * Fix _set_destination_retry_timings This came about because the code assumed that retry_interval could not be NULL — which has been challenged by catch-up. --- changelog.d/8230.bugfix | 1 + changelog.d/8230.misc | 1 - changelog.d/8247.bugfix | 1 + changelog.d/8247.misc | 1 - changelog.d/8258.bugfix | 1 + changelog.d/8258.misc | 1 - changelog.d/8322.bugfix | 1 + synapse/federation/sender/__init__.py | 51 +++++++++++++ synapse/storage/databases/main/transactions.py | 66 ++++++++++++++++- tests/federation/test_federation_catch_up.py | 99 ++++++++++++++++++++++++++ 10 files changed, 218 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8230.bugfix delete mode 100644 changelog.d/8230.misc create mode 100644 changelog.d/8247.bugfix delete mode 100644 changelog.d/8247.misc create mode 100644 changelog.d/8258.bugfix delete mode 100644 changelog.d/8258.misc create mode 100644 changelog.d/8322.bugfix (limited to 'tests') diff --git a/changelog.d/8230.bugfix b/changelog.d/8230.bugfix new file mode 100644 index 0000000000..532d0e22fe --- /dev/null +++ b/changelog.d/8230.bugfix @@ -0,0 +1 @@ +Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8230.misc b/changelog.d/8230.misc deleted file mode 100644 index bf0ba76730..0000000000 --- a/changelog.d/8230.misc +++ /dev/null @@ -1 +0,0 @@ -Track the latest event for every destination and room for catch-up after federation outage. diff --git a/changelog.d/8247.bugfix b/changelog.d/8247.bugfix new file mode 100644 index 0000000000..532d0e22fe --- /dev/null +++ b/changelog.d/8247.bugfix @@ -0,0 +1 @@ +Fix messages over federation being lost until an event is sent into the same room. 
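The startup catch-up added by this patch boils down to a small polling loop: page through the destinations that still have un-sent events in lexicographic batches, wake each one, and pause between wakes to avoid a thundering herd. The stand-alone sketch below condenses that pattern only as an illustration; the names (`outstanding_destinations`, `wake_destinations_needing_catchup`) and the tiny sleep interval are hypothetical stand-ins, not Synapse's actual API, which appears in `_wake_destinations_needing_catchup` and `get_catch_up_outstanding_destinations` further down in this patch::

    import asyncio
    from typing import List, Optional

    BATCH_SIZE = 25           # mirrors the LIMIT used by the real query
    WAKE_INTERVAL_SEC = 0.01  # stand-in for CATCH_UP_STARTUP_INTERVAL_SEC

    def outstanding_destinations(all_destinations: List[str], after: Optional[str]) -> List[str]:
        # Keyset pagination: everything sorts after the empty string, so passing
        # None yields the first batch; passing the last name seen yields the next.
        after = after or ""
        return sorted(d for d in all_destinations if d > after)[:BATCH_SIZE]

    async def wake_destinations_needing_catchup(all_destinations: List[str]) -> List[str]:
        woken = []  # type: List[str]
        last_processed = None  # type: Optional[str]
        while True:
            batch = outstanding_destinations(all_destinations, last_processed)
            if not batch:
                break  # finished waking every destination
            for last_processed in batch:
                woken.append(last_processed)  # the real loop calls wake_destination() here
                await asyncio.sleep(WAKE_INTERVAL_SEC)  # spread the wake-ups out
        return woken

    if __name__ == "__main__":
        servers = ["server%02d" % n for n in range(42)] + ["zzzerver"]
        assert asyncio.run(wake_destinations_needing_catchup(servers)) == sorted(servers)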
diff --git a/changelog.d/8247.misc b/changelog.d/8247.misc deleted file mode 100644 index 3c27803be4..0000000000 --- a/changelog.d/8247.misc +++ /dev/null @@ -1 +0,0 @@ -Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage. diff --git a/changelog.d/8258.bugfix b/changelog.d/8258.bugfix new file mode 100644 index 0000000000..532d0e22fe --- /dev/null +++ b/changelog.d/8258.bugfix @@ -0,0 +1 @@ +Fix messages over federation being lost until an event is sent into the same room. diff --git a/changelog.d/8258.misc b/changelog.d/8258.misc deleted file mode 100644 index 3c27803be4..0000000000 --- a/changelog.d/8258.misc +++ /dev/null @@ -1 +0,0 @@ -Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage. diff --git a/changelog.d/8322.bugfix b/changelog.d/8322.bugfix new file mode 100644 index 0000000000..532d0e22fe --- /dev/null +++ b/changelog.d/8322.bugfix @@ -0,0 +1 @@ +Fix messages over federation being lost until an event is sent into the same room. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 41a726878d..8bb17b3a05 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter( "Total number of PDUs queued for sending across all destinations", ) +# Time (in s) after Synapse's startup that we will begin to wake up destinations +# that have catch-up outstanding. +CATCH_UP_STARTUP_DELAY_SEC = 15 + +# Time (in s) to wait in between waking up each destination, i.e. one destination +# will be woken up every seconds after Synapse's startup until we have woken +# every destination has outstanding catch-up. +CATCH_UP_STARTUP_INTERVAL_SEC = 5 + class FederationSender: def __init__(self, hs: "synapse.server.HomeServer"): @@ -125,6 +134,14 @@ class FederationSender: 1000.0 / hs.config.federation_rr_transactions_per_room_per_second ) + # wake up destinations that have outstanding PDUs to be caught up + self._catchup_after_startup_timer = self.clock.call_later( + CATCH_UP_STARTUP_DELAY_SEC, + run_as_background_process, + "wake_destinations_needing_catchup", + self._wake_destinations_needing_catchup, + ) + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -560,3 +577,37 @@ class FederationSender: # Dummy implementation for case where federation sender isn't offloaded # to a worker. return [], 0, False + + async def _wake_destinations_needing_catchup(self): + """ + Wakes up destinations that need catch-up and are not currently being + backed off from. + + In order to reduce load spikes, adds a delay between each destination. + """ + + last_processed = None # type: Optional[str] + + while True: + destinations_to_wake = await self.store.get_catch_up_outstanding_destinations( + last_processed + ) + + if not destinations_to_wake: + # finished waking all destinations! 
+ self._catchup_after_startup_timer = None + break + + destinations_to_wake = [ + d + for d in destinations_to_wake + if self._federation_shard_config.should_handle(self._instance_name, d) + ] + + for last_processed in destinations_to_wake: + logger.info( + "Destination %s has outstanding catch-up, waking up.", + last_processed, + ) + self.wake_destination(last_processed) + await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 99cffff50c..97aed1500e 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -218,6 +218,7 @@ class TransactionStore(SQLBaseStore): retry_interval = EXCLUDED.retry_interval WHERE EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL OR destinations.retry_interval < EXCLUDED.retry_interval """ @@ -249,7 +250,11 @@ class TransactionStore(SQLBaseStore): "retry_interval": retry_interval, }, ) - elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + elif ( + retry_interval == 0 + or prev_row["retry_interval"] is None + or prev_row["retry_interval"] < retry_interval + ): self.db_pool.simple_update_one_txn( txn, "destinations", @@ -397,7 +402,7 @@ class TransactionStore(SQLBaseStore): @staticmethod def _get_catch_up_room_event_ids_txn( - txn, destination: str, last_successful_stream_ordering: int, + txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int, ) -> List[str]: q = """ SELECT event_id FROM destination_rooms @@ -412,3 +417,60 @@ class TransactionStore(SQLBaseStore): ) event_ids = [row[0] for row in txn] return event_ids + + async def get_catch_up_outstanding_destinations( + self, after_destination: Optional[str] + ) -> List[str]: + """ + Gets at most 25 destinations which have outstanding PDUs to be caught up, + and are not being backed off from + Args: + after_destination: + If provided, all destinations must be lexicographically greater + than this one. + + Returns: + list of up to 25 destinations with outstanding catch-up. + These are the lexicographically first destinations which are + lexicographically greater than after_destination (if provided). + """ + time = self.hs.get_clock().time_msec() + + return await self.db_pool.runInteraction( + "get_catch_up_outstanding_destinations", + self._get_catch_up_outstanding_destinations_txn, + time, + after_destination, + ) + + @staticmethod + def _get_catch_up_outstanding_destinations_txn( + txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] + ) -> List[str]: + q = """ + SELECT destination FROM destinations + WHERE destination IN ( + SELECT destination FROM destination_rooms + WHERE destination_rooms.stream_ordering > + destinations.last_successful_stream_ordering + ) + AND destination > ? + AND ( + retry_last_ts IS NULL OR + retry_last_ts + retry_interval < ? + ) + ORDER BY destination + LIMIT 25 + """ + txn.execute( + q, + ( + # everything is lexicographically greater than "" so this gives + # us the first batch of up to 25. 
+ after_destination or "", + now_time_ms, + ), + ) + + destinations = [row[0] for row in txn] + return destinations diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index cc52c3dfac..1a3ccb263d 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -321,3 +321,102 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): per_dest_queue._last_successful_stream_ordering, event_5.internal_metadata.stream_ordering, ) + + @override_config({"send_federation": True}) + def test_catch_up_on_synapse_startup(self): + """ + Tests the behaviour of get_catch_up_outstanding_destinations and + _wake_destinations_needing_catchup. + """ + + # list of sorted server names (note that there are more servers than the batch + # size used in get_catch_up_outstanding_destinations). + server_names = ["server%02d" % number for number in range(42)] + ["zzzerver"] + + # ARRANGE: + # - a local user (u1) + # - a room which u1 is joined to (and remote users @user:serverXX are + # joined to) + + # mark the remotes as online + self.is_online = True + + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_id = self.helper.create_room_as("u1", tok=u1_token) + + for server_name in server_names: + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, "@user:%s" % server_name, "join" + ) + ) + + # create an event + self.helper.send(room_id, "deary me!", tok=u1_token) + + # ASSERT: + # - All servers are up to date so none should have outstanding catch-up + outstanding_when_successful = self.get_success( + self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + ) + self.assertEqual(outstanding_when_successful, []) + + # ACT: + # - Make the remote servers unreachable + self.is_online = False + + # - Mark zzzerver as being backed-off from + now = self.clock.time_msec() + self.get_success( + self.hs.get_datastore().set_destination_retry_timings( + "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day + ) + ) + + # - Send an event + self.helper.send(room_id, "can anyone hear me?", tok=u1_token) + + # ASSERT (get_catch_up_outstanding_destinations): + # - all remotes are outstanding + # - they are returned in batches of 25, in order + outstanding_1 = self.get_success( + self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + ) + + self.assertEqual(len(outstanding_1), 25) + self.assertEqual(outstanding_1, server_names[0:25]) + + outstanding_2 = self.get_success( + self.hs.get_datastore().get_catch_up_outstanding_destinations( + outstanding_1[-1] + ) + ) + self.assertNotIn("zzzerver", outstanding_2) + self.assertEqual(len(outstanding_2), 17) + self.assertEqual(outstanding_2, server_names[25:-1]) + + # ACT: call _wake_destinations_needing_catchup + + # patch wake_destination to just count the destinations instead + woken = [] + + def wake_destination_track(destination): + woken.append(destination) + + self.hs.get_federation_sender().wake_destination = wake_destination_track + + # cancel the pre-existing timer for _wake_destinations_needing_catchup + # this is because we are calling it manually rather than waiting for it + # to be called automatically + self.hs.get_federation_sender()._catchup_after_startup_timer.cancel() + + self.get_success( + self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0 + ) + + # ASSERT (_wake_destinations_needing_catchup): + # - all remotes are woken up, save for zzzerver + 
self.assertNotIn("zzzerver", woken) + # - all destinations are woken exactly once; they appear once in woken. + self.assertCountEqual(woken, server_names[:-1]) -- cgit 1.5.1 From d688b4bafca58dfff1be35615d6ff1e202d47cc6 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 18 Sep 2020 16:26:36 +0200 Subject: Admin API for querying rooms where a user is a member (#8306) Add a new admin API `GET /_synapse/admin/v1/users//joined_rooms` to list all rooms where a user is a member. --- changelog.d/8306.feature | 1 + docs/admin_api/user_admin_api.rst | 37 +++++++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 26 +++++++++++ tests/rest/admin/test_user.py | 96 ++++++++++++++++++++++++++++++++++++++- 5 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8306.feature (limited to 'tests') diff --git a/changelog.d/8306.feature b/changelog.d/8306.feature new file mode 100644 index 0000000000..5c23da4030 --- /dev/null +++ b/changelog.d/8306.feature @@ -0,0 +1 @@ +Add an admin API for querying rooms where a user is a member. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index e21c78a9c6..7ca902faba 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -304,6 +304,43 @@ To use it, you will need to authenticate by providing an ``access_token`` for a server admin: see `README.rst `_. +List room memberships of an user +================================ +Gets a list of all ``room_id`` that a specific ``user_id`` is member. + +The API is:: + + GET /_synapse/admin/v1/users//joined_rooms + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +A response body like the following is returned: + +.. code:: json + + { + "joined_rooms": [ + "!DuGcnbhHGaSZQoNQR:matrix.org", + "!ZtSaPCawyWtxfWiIy:matrix.org" + ], + "total": 2 + } + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. + +**Response** + +The following fields are returned in the JSON response body: + +- ``joined_rooms`` - An array of ``room_id``. +- ``total`` - Number of rooms. + + User devices ============ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index abf362c7b7..4a75c06480 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -49,6 +49,7 @@ from synapse.rest.admin.users import ( ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, + UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, UsersRestServlet, @@ -209,6 +210,7 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserMembershipRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 0f537031c4..20dc1d0e05 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -683,3 +683,29 @@ class UserAdminServlet(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) return 200, {} + + +class UserMembershipRestServlet(RestServlet): + """ + Get room list of an user. 
+ """ + + PATTERNS = admin_patterns("/users/(?P[^/]+)/joined_rooms$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + room_ids = await self.store.get_rooms_for_user(user_id) + if not room_ids: + raise NotFoundError("User not found") + + ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + return 200, ret diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index b8b7758d24..f96011fc1c 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -22,8 +22,8 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.api.errors import HttpResponseException, ResourceLimitError -from synapse.rest.client.v1 import login +from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests import unittest @@ -995,3 +995,95 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + + +class UserMembershipRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list rooms of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_rooms(self): + """ + Tests that a normal lookup for rooms is successfully + """ + # Create rooms and join + other_user_tok = self.login("user", "pass") + number_rooms = 5 + for n in range(number_rooms): + self.helper.create_room_as(self.other_user, tok=other_user_tok) + + # Get rooms + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_rooms, channel.json_body["total"]) + self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) -- cgit 1.5.1 From 37ca5924bddccc37521798236339b539677d101f Mon Sep 17 00:00:00 2001 From: Dionysis Grigoropoulos Date: Tue, 22 Sep 2020 13:42:55 +0300 Subject: Create function to check for long names in devices (#8364) * Create a new function to verify that the length of a device name is under a certain threshold. * Refactor old code and tests to use said function. * Verify device name length during registration of device * Add a test for the above Signed-off-by: Dionysis Grigoropoulos --- changelog.d/8364.bugfix | 2 ++ synapse/handlers/device.py | 30 ++++++++++++++++++++++++------ tests/handlers/test_device.py | 11 +++++++++++ tests/rest/admin/test_device.py | 2 +- 4 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 changelog.d/8364.bugfix (limited to 'tests') diff --git a/changelog.d/8364.bugfix b/changelog.d/8364.bugfix new file mode 100644 index 0000000000..7b82cbc388 --- /dev/null +++ b/changelog.d/8364.bugfix @@ -0,0 +1,2 @@ +Fix a bug where during device registration the length of the device name wasn't +limited. 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 55a9787439..4149520d6c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, @@ -265,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) + def _check_device_name_length(self, name: str): + """ + Checks whether a device name is longer than the maximum allowed length. + + Args: + name: The name of the device. + + Raises: + SynapseError: if the device name is too long. + """ + if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + errcode=Codes.TOO_LARGE, + ) + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): @@ -282,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler): Returns: str: device id (generated if none was supplied) """ + + self._check_device_name_length(initial_device_display_name) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -397,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler): # Reject a new displayname which is too long. new_display_name = content.get("display_name") - if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: - raise SynapseError( - 400, - "Device display name is too long (max %i)" - % (MAX_DEVICE_DISPLAY_NAME_LEN,), - ) + + self._check_device_name_length(new_display_name) try: await self.store.update_device( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 6aa322bf3a..969d44c787 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase): # These tests assume that it starts 1000 seconds in. self.reactor.advance(1000) + def test_device_is_created_with_invalid_name(self): + self.get_failure( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="foo", + initial_device_display_name="a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1), + ), + synapse.api.errors.SynapseError, + ) + def test_device_is_created_if_doesnt_exist(self): res = self.get_success( self.handler.check_device_registered( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index faa7f381a9..92c9058887 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. request, channel = self.make_request( -- cgit 1.5.1 From 4da01f9c614f36a293235d6a1fd3602d550f2001 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 22 Sep 2020 19:15:04 +0200 Subject: Admin API for reported events (#8217) Add an admin API to read entries of table `event_reports`. 
API: `GET /_synapse/admin/v1/event_reports` --- changelog.d/8217.feature | 1 + docs/admin_api/event_reports.rst | 129 +++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/event_reports.py | 88 ++++++++ synapse/storage/databases/main/room.py | 95 ++++++++ tests/rest/admin/test_event_reports.py | 382 +++++++++++++++++++++++++++++++++ 6 files changed, 697 insertions(+) create mode 100644 changelog.d/8217.feature create mode 100644 docs/admin_api/event_reports.rst create mode 100644 synapse/rest/admin/event_reports.py create mode 100644 tests/rest/admin/test_event_reports.py (limited to 'tests') diff --git a/changelog.d/8217.feature b/changelog.d/8217.feature new file mode 100644 index 0000000000..899cbf14ef --- /dev/null +++ b/changelog.d/8217.feature @@ -0,0 +1 @@ +Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst new file mode 100644 index 0000000000..461be01230 --- /dev/null +++ b/docs/admin_api/event_reports.rst @@ -0,0 +1,129 @@ +Show reported events +==================== + +This API returns information about reported events. + +The api is:: + + GET /_synapse/admin/v1/event_reports?from=0&limit=10 + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +It returns a JSON body like the following: + +.. code:: jsonc + + { + "event_reports": [ + { + "content": { + "reason": "foo", + "score": -100 + }, + "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", + "event_json": { + "auth_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", + "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" + ], + "content": { + "body": "matrix.org: This Week in Matrix", + "format": "org.matrix.custom.html", + "formatted_body": "matrix.org:
This Week in Matrix", + "msgtype": "m.notice" + }, + "depth": 546, + "hashes": { + "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" + }, + "origin": "matrix.org", + "origin_server_ts": 1592291711430, + "prev_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" + ], + "prev_state": [], + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "signatures": { + "matrix.org": { + "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" + } + }, + "type": "m.room.message", + "unsigned": { + "age_ts": 1592291711430, + } + }, + "id": 2, + "reason": "foo", + "received_ts": 1570897107409, + "room_alias": "#alias1:matrix.org", + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "user_id": "@foo:matrix.org" + }, + { + "content": { + "reason": "bar", + "score": -100 + }, + "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I", + "event_json": { + // hidden items + // see above + }, + "id": 3, + "reason": "bar", + "received_ts": 1598889612059, + "room_alias": "#alias2:matrix.org", + "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org", + "sender": "@foobar:matrix.org", + "user_id": "@bar:matrix.org" + } + ], + "next_token": 2, + "total": 4 + } + +To paginate, check for ``next_token`` and if present, call the endpoint again +with ``from`` set to the value of ``next_token``. This will return a new page. + +If the endpoint does not return a ``next_token`` then there are no more +reports to paginate through. + +**URL parameters:** + +- ``limit``: integer - Is optional but is used for pagination, + denoting the maximum number of items to return in this call. Defaults to ``100``. +- ``from``: integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value and + not explicitly set to anything other than the return value of ``next_token`` from a previous call. + Defaults to ``0``. +- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the + oldest first (``f``). Defaults to ``b``. +- ``user_id``: string - Is optional and filters to only return users with user IDs that contain this value. + This is the user who reported the event and wrote the reason. +- ``room_id``: string - Is optional and filters to only return rooms with room IDs that contain this value. + +**Response** + +The following fields are returned in the JSON response body: + +- ``id``: integer - ID of event report. +- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent. +- ``room_id``: string - The ID of the room in which the event being reported is located. +- ``event_id``: string - The ID of the reported event. +- ``user_id``: string - This is the user who reported the event and wrote the reason. +- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. +- ``content``: object - Content of reported event. + + - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. + - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". + +- ``sender``: string - This is the ID of the user who sent the original message/event that was reported. +- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set. +- ``event_json``: object - Details of the original event that was reported. 
+- ``next_token``: integer - Indication for pagination. See above. +- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``). + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 4a75c06480..5c5f00b213 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -31,6 +31,7 @@ from synapse.rest.admin.devices import ( DeviceRestServlet, DevicesRestServlet, ) +from synapse.rest.admin.event_reports import EventReportsRestServlet from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -216,6 +217,7 @@ def register_servlets(hs, http_server): DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + EventReportsRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py new file mode 100644 index 0000000000..5b8d0594cd --- /dev/null +++ b/synapse/rest/admin/event_reports.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin + +logger = logging.getLogger(__name__) + + +class EventReportsRestServlet(RestServlet): + """ + List all reported events that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results. + The parameter `user_id` can be used to filter by user id. + The parameter `room_id` can be used to filter by room id. 
+ Returns: + A list of reported events and an integer representing the total number of + reported events that exist given this query + """ + + PATTERNS = admin_patterns("/event_reports$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + await assert_requester_is_admin(self.auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_string(request, "dir", default="b") + user_id = parse_string(request, "user_id") + room_id = parse_string(request, "room_id") + + if start < 0: + raise SynapseError( + 400, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + event_reports, total = await self.store.get_event_reports_paginate( + start, limit, direction, user_id, room_id + ) + ret = {"event_reports": event_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(event_reports) + + return 200, ret diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index bd6f9553c6..3ee097abf7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: str = "b", + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + event_reports: json list of event reports + count: total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn(txn): + filters = [] + args = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.reason, + er.content, + events.sender, + room_aliases.room_alias, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN room_aliases + ON room_aliases.room_id = er.room_id + JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? 
+ """.format( + where_clause=where_clause, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + event_reports = self.db_pool.cursor_to_dict(txn) + + if count > 0: + for row in event_reports: + try: + row["content"] = db_to_json(row["content"]) + row["event_json"] = db_to_json(row["event_json"]) + except Exception: + continue + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py new file mode 100644 index 0000000000..bf79086f78 --- /dev/null +++ b/tests/rest/admin/test_event_reports.py @@ -0,0 +1,382 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import report_event + +from tests import unittest + + +class EventReportsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self.room_id2 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok) + + # Two rooms and two users. Every user sends and reports every room event + for i in range(5): + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id2, user_tok=self.other_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.admin_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id2, user_tok=self.admin_user_tok, + ) + + self.url = "/_synapse/admin/v1/event_reports" + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. 
+ """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing list of reported events + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + def test_limit(self): + """ + Testing list of reported events with limit + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["event_reports"]) + + def test_from(self): + """ + Testing list of reported events with a defined starting point (from) + """ + + request, channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + def test_limit_and_from(self): + """ + Testing list of reported events with a defined starting point and limit + """ + + request, channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self._check_fields(channel.json_body["event_reports"]) + + def test_filter_room(self): + """ + Testing list of reported events with a filter of room + """ + + request, channel = self.make_request( + "GET", + self.url + "?room_id=%s" % self.room_id1, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["room_id"], self.room_id1) + + def test_filter_user(self): + """ + Testing list of reported events with a filter of user + """ + + request, channel = self.make_request( + "GET", + self.url + "?user_id=%s" % self.other_user, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + 
self.assertEqual(len(channel.json_body["event_reports"]), 10) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["user_id"], self.other_user) + + def test_filter_user_and_room(self): + """ + Testing list of reported events with a filter of user and room + """ + + request, channel = self.make_request( + "GET", + self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 5) + self.assertEqual(len(channel.json_body["event_reports"]), 5) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["user_id"], self.other_user) + self.assertEqual(report["room_id"], self.room_id1) + + def test_valid_search_order(self): + """ + Testing search order. Order by timestamps. + """ + + # fetch the most recent first, largest timestamp + request, channel = self.make_request( + "GET", self.url + "?dir=b", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + report = 1 + while report < len(channel.json_body["event_reports"]): + self.assertGreaterEqual( + channel.json_body["event_reports"][report - 1]["received_ts"], + channel.json_body["event_reports"][report]["received_ts"], + ) + report += 1 + + # fetch the oldest first, smallest timestamp + request, channel = self.make_request( + "GET", self.url + "?dir=f", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + report = 1 + while report < len(channel.json_body["event_reports"]): + self.assertLessEqual( + channel.json_body["event_reports"][report - 1]["received_ts"], + channel.json_body["event_reports"][report]["received_ts"], + ) + report += 1 + + def test_invalid_search_order(self): + """ + Testing that a invalid search order returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual("Unknown direction: bar", channel.json_body["error"]) + + def test_limit_is_negative(self): + """ + Testing that a negative list parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), 
msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + # `next_token` does not appear + # Number of results is the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + request, channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + request, channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + for c in content: + self.assertIn("id", c) + self.assertIn("received_ts", c) + self.assertIn("room_id", c) + self.assertIn("event_id", c) + self.assertIn("user_id", c) + self.assertIn("reason", c) + self.assertIn("content", c) + self.assertIn("sender", c) + self.assertIn("room_alias", c) + self.assertIn("event_json", c) + self.assertIn("score", c["content"]) + self.assertIn("reason", c["content"]) + self.assertIn("auth_events", c["event_json"]) + self.assertIn("type", c["event_json"]) + self.assertIn("room_id", c["event_json"]) + self.assertIn("sender", c["event_json"]) + self.assertIn("content", c["event_json"]) -- cgit 1.5.1 From 8998217540bc41975e64e44c507632361ca95698 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 22 Sep 2020 19:19:01 +0200 Subject: Fixed a bug with reactivating users 
with the admin API (#8362) Fixes: #8359 Trying to reactivate a user with the admin API (`PUT /_synapse/admin/v2/users/`) causes an internal server error. Seems to be a regression in #8033. --- changelog.d/8362.bugfix | 1 + synapse/storage/databases/main/user_erasure_store.py | 2 +- tests/rest/admin/test_user.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8362.bugfix (limited to 'tests') diff --git a/changelog.d/8362.bugfix b/changelog.d/8362.bugfix new file mode 100644 index 0000000000..4e50067c87 --- /dev/null +++ b/changelog.d/8362.bugfix @@ -0,0 +1 @@ +Fixed a regression in v1.19.0 with reactivating users through the admin API. diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 2f7c95fc74..f9575b1f1f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore): return # They are there, delete them. - self.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "erased_users", keyvalues={"user_id": user_id} ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f96011fc1c..98d0623734 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._is_erased("@user:test", False) + d = self.store.mark_user_erased("@user:test") + self.assertIsNone(self.get_success(d)) + self._is_erased("@user:test", True) # Attempt to reactivate the user (without a password). request, channel = self.make_request( @@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) + self._is_erased("@user:test", False) def test_set_user_as_admin(self): """ @@ -996,6 +1001,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + def _is_erased(self, user_id, expect): + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From cbabb312e0b59090e5a8cf9e7e016a8618e62867 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Sep 2020 16:11:18 +0100 Subject: Use `async with` for ID gens (#8383) This will allow us to hit the DB after we've finished using the generated stream ID. 
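Stripped to its essentials, the change below is a move from `with await gen.get_next()` to `async with gen.get_next()`, so that the generator's exit path can itself await (for example, to record in the database that the allocated stream ID has been fully persisted). The following toy sketch illustrates that call style with an assumed in-memory `ToyIdGenerator`, not Synapse's real ID generator classes::

    import asyncio
    import contextlib

    class ToyIdGenerator:
        """Toy in-memory stand-in for Synapse's stream ID generators."""

        def __init__(self) -> None:
            self._current = 0

        @contextlib.asynccontextmanager
        async def get_next(self):
            self._current += 1
            stream_id = self._current
            try:
                yield stream_id
            finally:
                # Because this is an *async* context manager we may await on exit,
                # e.g. to mark the ID as persisted once the caller has finished
                # writing its row.
                await asyncio.sleep(0)  # placeholder for an awaitable DB write

    async def main() -> None:
        gen = ToyIdGenerator()
        async with gen.get_next() as stream_id:  # the new `async with` call style
            print("allocated stream ordering", stream_id)

    asyncio.run(main())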
--- changelog.d/8383.misc | 1 + synapse/storage/databases/main/account_data.py | 4 +- synapse/storage/databases/main/deviceinbox.py | 4 +- synapse/storage/databases/main/devices.py | 6 +- synapse/storage/databases/main/end_to_end_keys.py | 2 +- synapse/storage/databases/main/events.py | 6 +- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/presence.py | 4 +- synapse/storage/databases/main/push_rule.py | 8 +- synapse/storage/databases/main/pusher.py | 4 +- synapse/storage/databases/main/receipts.py | 2 +- synapse/storage/databases/main/room.py | 6 +- synapse/storage/databases/main/tags.py | 4 +- synapse/storage/util/id_generators.py | 130 +++++++++++++--------- tests/storage/test_id_generators.py | 66 ++++++----- 15 files changed, 144 insertions(+), 105 deletions(-) create mode 100644 changelog.d/8383.misc (limited to 'tests') diff --git a/changelog.d/8383.misc b/changelog.d/8383.misc new file mode 100644 index 0000000000..cb8318bf57 --- /dev/null +++ b/changelog.d/8383.misc @@ -0,0 +1 @@ +Refactor ID generators to use `async with` syntax. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c5a36990e4..ef81d73573 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index e71217a41f..d42faa3f1f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index c04374e43d..fdf394c612 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. 
""" - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3f..22e1ed15d0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a80f419e3..7723d82496 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -156,15 +156,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. 
if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ccfbb2135e..7218191965 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e20a16f907..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore): Raises: NotFoundError if the rule does not exist. 
""" - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index f880b5e562..c79ddff680 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3ee097abf7..3c7630857f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await 
self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1de2b91587..b0353ac2dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque from synapse.storage.database import DatabasePool, LoggingTransaction @@ -86,7 +86,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +101,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +113,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +121,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +140,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +149,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -282,59 +282,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. 
- with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) - - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) - return manager() + return _MultiWriterCtxManager(self) - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -482,3 +446,61 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + return False diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 20636fc400..fb8f5bc255 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. 
async def _get_next_async(): - with await id_gen.get_next() as stream_id: + async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) @@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ctx3 = self.get_success(id_gen.get_next()) ctx4 = self.get_success(id_gen.get_next()) - s1 = ctx1.__enter__() - s2 = ctx2.__enter__() - s3 = ctx3.__enter__() - s4 = ctx4.__enter__() + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + s3 = self.get_success(ctx3.__aenter__()) + s4 = self.get_success(ctx4.__aenter__()) self.assertEqual(s1, 8) self.assertEqual(s2, 9) @@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx2.__exit__(None, None, None) + self.get_success(ctx2.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1.__exit__(None, None, None) + self.get_success(ctx1.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx4.__exit__(None, None, None) + self.get_success(ctx4.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx3.__exit__(None, None, None) + self.get_success(ctx3.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) @@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. 
async def _get_next_async(): - with await first_id_gen.get_next() as stream_id: + async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual( @@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # stream ID async def _get_next_async(): - with await second_id_gen.get_next() as stream_id: + async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) self.assertEqual( @@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 3) - with self.get_success(id_gen.get_next()) as stream_id: - self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_persisted_upto_position(), 6) @@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ id_gen = self._create_id_generator() - with self.get_success(id_gen.get_next()) as stream_id: - self._insert_row("master", stream_id) + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": -1}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - with self.get_success(id_gen.get_next_mult(3)) as stream_ids: - for stream_id in stream_ids: - self._insert_row("master", stream_id) + async def _get_next_async2(): + async with id_gen.get_next_mult(3) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen.get_positions(), {"master": -4}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) @@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_1 = self._create_id_generator("first") id_gen_2 = self._create_id_generator("second") - with self.get_success(id_gen_1.get_next()) as stream_id: - self._insert_row("first", stream_id) - id_gen_2.advance("first", stream_id) + async def _get_next_async(): + async with id_gen_1.get_next() as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - with self.get_success(id_gen_2.get_next()) as stream_id: - self._insert_row("second", stream_id) - id_gen_1.advance("second", stream_id) + async def _get_next_async2(): + async with id_gen_2.get_next() as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) -- cgit 1.5.1 From ac11fcbbb8ccfeb4c72b5aae9faef28469109277 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 24 Sep 2020 13:24:17 +0100 Subject: Add EventStreamPosition type (#8388) The idea is to remove some of 
the places we pass around `int`, where it can represent one of two things: 1. the position of an event in the stream; or 2. a token that partitions the stream, used as part of the stream tokens. The valid operations are then: 1. did a position happen before or after a token; 2. get all events that happened before or after a token; and 3. get all events between two tokens. (Note that we don't want to allow other operations as we want to change the tokens to be vector clocks rather than simple ints) --- changelog.d/8388.misc | 1 + synapse/handlers/federation.py | 16 +++++--- synapse/handlers/message.py | 6 +-- synapse/handlers/sync.py | 10 ++--- synapse/notifier.py | 55 ++++++++++++++------------ synapse/replication/tcp/client.py | 12 ++++-- synapse/storage/databases/main/roommember.py | 14 ++++--- synapse/storage/persist_events.py | 14 ++++--- synapse/storage/roommember.py | 2 +- synapse/types.py | 15 +++++++ tests/replication/slave/storage/test_events.py | 12 ++++-- 11 files changed, 100 insertions(+), 57 deletions(-) create mode 100644 changelog.d/8388.misc (limited to 'tests') diff --git a/changelog.d/8388.misc b/changelog.d/8388.misc new file mode 100644 index 0000000000..aaaef88b66 --- /dev/null +++ b/changelog.d/8388.misc @@ -0,0 +1 @@ +Add `EventStreamPosition` type. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ea9264e751..9f773aefa7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( JsonDict, MutableStateMap, + PersistedEventPosition, + RoomStreamToken, StateMap, UserID, get_domain_from_id, @@ -2956,7 +2958,7 @@ class FederationHandler(BaseHandler): ) return result["max_stream_id"] else: - max_stream_id = await self.storage.persistence.persist_events( + max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -2967,12 +2969,12 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - await self._notify_persisted_event(event, max_stream_id) + await self._notify_persisted_event(event, max_stream_token) - return max_stream_id + return max_stream_token.stream async def _notify_persisted_event( - self, event: EventBase, max_stream_id: int + self, event: EventBase, max_stream_token: RoomStreamToken ) -> None: """Checks to see if notifier/pushers should be notified about the event or not. 
@@ -2998,9 +3000,11 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return - event_stream_id = event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( + self._instance_name, event.internal_metadata.stream_ordering + ) self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) async def _clean_room_for_join(self, room_id: str) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6ee559fd1d..ee271e85e5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1138,7 +1138,7 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = await self.storage.persistence.persist_event( + event_pos, max_stream_token = await self.storage.persistence.persist_event( event, context=context ) @@ -1149,7 +1149,7 @@ class EventCreationHandler: def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -1161,7 +1161,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - return event_stream_id + return event_pos.stream async def _bump_active_time(self, user: UserID) -> None: try: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9b3a4f638b..e948efef2e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -967,7 +967,7 @@ class SyncHandler: raise NotImplementedError() else: joined_room_ids = await self.get_rooms_for_user_at( - user_id, now_token.room_stream_id + user_id, now_token.room_key ) sync_result_builder = SyncResultBuilder( sync_config, @@ -1916,7 +1916,7 @@ class SyncHandler: raise Exception("Unrecognized rtype: %r", room_builder.rtype) async def get_rooms_for_user_at( - self, user_id: str, stream_ordering: int + self, user_id: str, room_key: RoomStreamToken ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. @@ -1942,15 +1942,15 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. 
- for room_id, membership_stream_ordering in joined_rooms: - if membership_stream_ordering <= stream_ordering: + for room_id, event_pos in joined_rooms: + if not event_pos.persisted_after(room_key): joined_room_ids.add(room_id) continue logger.info("User joined room after current token: %s", room_id) extrems = await self.store.get_forward_extremeties_for_room( - room_id, stream_ordering + room_id, event_pos.stream ) users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: diff --git a/synapse/notifier.py b/synapse/notifier.py index a8fd3ef886..441b3d15e2 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -42,7 +42,13 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig -from synapse.types import Collection, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Collection, + PersistedEventPosition, + RoomStreamToken, + StreamToken, + UserID, +) from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -187,7 +193,7 @@ class Notifier: self.store = hs.get_datastore() self.pending_new_room_events = ( [] - ) # type: List[Tuple[int, EventBase, Collection[UserID]]] + ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -246,8 +252,8 @@ class Notifier: def on_new_room_event( self, event: EventBase, - room_stream_id: int, - max_room_stream_id: int, + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): """ Used by handlers to inform the notifier something has happened @@ -261,16 +267,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((room_stream_id, event, extra_users)) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append((event_pos, event, extra_users)) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id: int): + def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id: The highest stream_id below which all + max_room_stream_token: The highest stream_id below which all events have been persisted. 
""" pending = self.pending_new_room_events @@ -279,11 +285,9 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for room_stream_id, event, extra_users in pending: - if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append( - (room_stream_id, event, extra_users) - ) + for event_pos, event, extra_users in pending: + if event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append((event_pos, event, extra_users)) else: if ( event.type == EventTypes.Member @@ -296,39 +300,38 @@ class Notifier: if users or rooms: self.on_new_event( - "room_key", - RoomStreamToken(None, max_room_stream_id), - users=users, - rooms=rooms, + "room_key", max_room_stream_token, users=users, rooms=rooms, ) - self._on_updated_room_token(max_room_stream_id) + self._on_updated_room_token(max_room_stream_token) - def _on_updated_room_token(self, max_room_stream_id: int): + def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken): """Poke services that might care that the room position has been updated. """ # poke any interested application service. run_as_background_process( - "_notify_app_services", self._notify_app_services, max_room_stream_id + "_notify_app_services", self._notify_app_services, max_room_stream_token ) run_as_background_process( - "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id + "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token ) if self.federation_sender: - self.federation_sender.notify_new_events(max_room_stream_id) + self.federation_sender.notify_new_events(max_room_stream_token.stream) - async def _notify_app_services(self, max_room_stream_id: int): + async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services(max_room_stream_id) + await self.appservice_handler.notify_interested_services( + max_room_stream_token.stream + ) except Exception: logger.exception("Error notifying application services of event") - async def _notify_pusher_pool(self, max_room_stream_id: int): + async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_id) + await self._pusher_pool.on_new_notifications(max_room_stream_token.stream) except Exception: logger.exception("Error pusher pool of event") diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e82b9e386f..55af3d41ea 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import UserID +from synapse.types import PersistedEventPosition, RoomStreamToken, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -151,8 +151,14 @@ class ReplicationDataHandler: extra_users = () # type: Tuple[UserID, ...] if event.type == EventTypes.Member: extra_users = (UserID.from_string(event.state_key),) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event(event, token, max_token, extra_users) + + max_token = RoomStreamToken( + None, self.store.get_room_max_stream_ordering() + ) + event_pos = PersistedEventPosition(instance_name, token) + self.notifier.on_new_room_event( + event, event_pos, max_token, extra_users + ) # Notify any waiting deferreds. 
The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4fa8767b01..86ffe2479e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set @@ -37,7 +36,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # for rooms the server is participating in. if self._current_state_events_membership_up_to_date: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ else: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN room_memberships AS m USING (room_id, event_id) INNER JOIN events AS e USING (room_id, event_id) @@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + return frozenset( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + for room_id, instance, stream_id in txn + ) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index d89f6ed128..603cd7d825 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, StateMap +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -190,6 +190,7 @@ class EventsPersistenceStorage: self.persist_events_store = stores.persist_events self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() @@ -198,7 +199,7 @@ class EventsPersistenceStorage: self, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> RoomStreamToken: """ Write events to the database Args: @@ -228,11 +229,11 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return 
self.main_store.get_current_events_token() + return RoomStreamToken(None, self.main_store.get_current_events_token()) async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[int, int]: + ) -> Tuple[PersistedEventPosition, RoomStreamToken]: """ Returns: The stream ordering of `event`, and the stream ordering of the @@ -247,7 +248,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) max_persisted_id = self.main_store.get_current_events_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) + event_stream_id = event.internal_metadata.stream_ordering + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return pos, RoomStreamToken(None, max_persisted_id) def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8c4a83a840..f152f63321 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -25,7 +25,7 @@ RoomsForUser = namedtuple( ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") + "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") ) diff --git a/synapse/types.py b/synapse/types.py index a6fc7df22c..ec39f9e1e8 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -495,6 +495,21 @@ class StreamToken: StreamToken.START = StreamToken.from_string("s0_0") +@attr.s(slots=True, frozen=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name = attr.ib(type=str) + stream = attr.ib(type=int) + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.stream < self.stream + + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) ): diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index bc578411d6..c0ee1cfbd6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_ from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser +from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) self.replicate() + + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, j2.internal_metadata.stream_ordering)}, + {(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # the membership change is only any use to us if the room is in the # joined_rooms list. 
if membership_changes: - self.assertEqual( - joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)} + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering ) + self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) event_id = 0 -- cgit 1.5.1 From f112cfe5bb2c918c9e942941686a05664d8bd7da Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 24 Sep 2020 16:53:51 +0100 Subject: Fix MultiWriteIdGenerator's handling of restarts. (#8374) On startup `MultiWriteIdGenerator` fetches the maximum stream ID for each instance from the table and uses that as its initial "current position" for each writer. This is problematic as a) it involves either a scan of events table or an index (neither of which is ideal), and b) if rows are being persisted out of order elsewhere while the process restarts then using the maximum stream ID is not correct. This could theoretically lead to race conditions where e.g. events that are persisted out of order are not sent down sync streams. We fix this by creating a new table that tracks the current positions of each writer to the stream, and update it each time we finish persisting a new entry. This is a relatively small overhead when persisting events. However for the cache invalidation stream this is a much bigger relative overhead, so instead we note that for invalidation we don't actually care about reliability over restarts (as there's no caches to invalidate) and simply don't bother reading and writing to the new table in that particular case. --- changelog.d/8374.bugfix | 1 + synapse/replication/slave/storage/_base.py | 2 + synapse/storage/databases/main/__init__.py | 8 +- synapse/storage/databases/main/events_worker.py | 4 + .../main/schema/delta/58/18stream_positions.sql | 22 +++ synapse/storage/util/id_generators.py | 148 ++++++++++++++++++--- tests/storage/test_id_generators.py | 119 +++++++++++++++-- 7 files changed, 274 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8374.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/58/18stream_positions.sql (limited to 'tests') diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix new file mode 100644 index 0000000000..155bc3404f --- /dev/null +++ b/changelog.d/8374.bugfix @@ -0,0 +1 @@ +Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. 
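To make the recovery rule concrete (a toy sketch with invented names and numbers, not the schema or code in this patch): because rows can be persisted out of order, the largest row in the table is not a safe position to resume from after a restart; only the positions each writer has explicitly recorded are, and the stream is reliably complete only up to the minimum of those.

    # Toy numbers and names, for illustration only.
    recorded_positions = {"persister1": 107, "persister2": 103}

    # The table may already contain rows above 103 that were written out of
    # order, so resuming from MAX(stream_ordering) could skip in-flight events.
    max_row_in_table = 110


    def persisted_upto(positions: dict) -> int:
        """Stream position up to which every writer is known to have finished."""
        return min(positions.values(), default=0)


    assert persisted_upto(recorded_positions) == 103
    assert persisted_upto(recorded_positions) < max_row_in_table
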
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d25fa49e1a..d0089fe06c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index ccb3384db9..0cb12f4c61 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,14 +160,20 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index de9e8d1dc6..f95679ebc4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="events", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="backfill", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_backfill_stream_seq", positive=False, + writers=hs.config.worker.writers.events, ) else: # We shouldn't be running in worker mode with SQLite, but its useful diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b0353ac2dc..727fcc521c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Union import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -184,12 +185,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +203,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,9 +225,7 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). @@ -251,30 +258,80 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? 
+ """ + sql = self._db.engine.convert_param_style(sql) - cur.close() + cur.execute(sql, (self._stream_name,)) + + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } + + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + sql = """ + SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id,)) + + self._persisted_upto_position = min_stream_id + + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) - return current_positions + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -316,6 +373,21 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): @@ -447,6 +519,28 @@ class MultiWriterIdGenerator: # do. break + def _update_stream_positions_table_txn(self, txn): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) 
+ ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + @attr.s(slots=True) class _AsyncCtxManagerWrapper: @@ -503,4 +597,16 @@ class _MultiWriterCtxManager: if exc_type is not None: return False + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + ) + return False diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index fb8f5bc255..d4ff55fbff 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, ) return self.get_success(self.db_pool.runWithConnection(_create)) @@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", (instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) @@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? 
+ """, + (instance_name, stream_id, stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) @@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_rows("first", 3) self._insert_rows("second", 4) - first_id_gen = self._create_id_generator("first") - second_id_gen = self._create_id_generator("second") + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) @@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + def test_restart_during_out_of_order_persistence(self): + """Test that restarting a process while another process is writing out + of order updates are handled correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # Persist two rows at once + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) + + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") + + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + + # Now if we persist the first row then both instances should jump ahead + # correctly. + self.get_success(ctx1.__aexit__(None, None, None)) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + + def test_writer_config_change(self): + """Test that changing the writer config correctly works. 
+ """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + # Initial config has two writers + id_gen = self._create_id_generator("first", writers=["first", "second"]) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # New config removes one of the configs. Note that if the writer is + # removed from config we assume that it has been shut down and has + # finished persisting, hence why the persisted upto position is 5. + id_gen_2 = self._create_id_generator("second", writers=["second"]) + self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + + # This config points to a single, previously unused writer. + id_gen_3 = self._create_id_generator("third", writers=["third"]) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + + # Check that we get a sane next stream ID with this new config. + + async def _get_next_async(): + async with id_gen_3.get_next() as stream_id: + self.assertEqual(stream_id, 6) + + self.get_success(_get_next_async()) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. @@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, positive=False, ) @@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): txn.execute( "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, -stream_id, -stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) @@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests that having multiple instances that get advanced over federation works corretly. """ - id_gen_1 = self._create_id_generator("first") - id_gen_2 = self._create_id_generator("second") + id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) + id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) async def _get_next_async(): async with id_gen_1.get_next() as stream_id: -- cgit 1.5.1 From abd04b6af0671517a01781c8bd10fef2a6c32cc4 Mon Sep 17 00:00:00 2001 From: Tdxdxoz Date: Fri, 25 Sep 2020 19:01:45 +0800 Subject: Allow existing users to login via OpenID Connect. (#8345) Co-authored-by: Benjamin Koch This adds configuration flags that will match a user to pre-existing users when logging in via OpenID Connect. This is useful when switching to an existing SSO system. 
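In outline, the new matching rule behaves like the minimal sketch below (an illustration only, not the handler code itself; `users` stands for the dict returned by `get_users_by_id_case_insensitive`, and the exception type is simplified):

from typing import Dict, Optional

def pick_existing_user(
    user_id: str, users: Dict[str, object], allow_existing_users: bool
) -> Optional[str]:
    """Return the mxid of an existing account to log in as, or None to register."""
    if not users:
        # The mapped mxid is free: fall through to normal registration.
        return None
    if not allow_existing_users:
        raise ValueError("mxid '%s' is already taken" % (user_id,))
    if len(users) == 1:
        # A single (possibly case-insensitively matching) account exists.
        return next(iter(users))
    if user_id in users:
        # Several case-insensitive matches, but one is exact: prefer it.
        return user_id
    raise ValueError("mxid '%s' matches more than one user inexactly" % (user_id,))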
--- changelog.d/8345.feature | 1 + docs/sample_config.yaml | 5 +++ synapse/config/oidc_config.py | 6 ++++ synapse/handlers/oidc_handler.py | 42 +++++++++++++++++--------- synapse/storage/databases/main/registration.py | 4 +-- tests/handlers/test_oidc.py | 35 +++++++++++++++++++++ 6 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8345.feature (limited to 'tests') diff --git a/changelog.d/8345.feature b/changelog.d/8345.feature new file mode 100644 index 0000000000..4ee5b6a56e --- /dev/null +++ b/changelog.d/8345.feature @@ -0,0 +1 @@ +Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fb04ff283d..845f537795 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1689,6 +1689,11 @@ oidc_config: # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e0939bce84..70fc8a2f62 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -56,6 +56,7 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -158,6 +159,11 @@ class OIDCConfig(Config): # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4230dbaf99..0e06e4408d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -114,6 +114,7 @@ class OidcHandler: hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -849,7 +850,8 @@ class OidcHandler: If we don't find the user that way, we should register the user, mapping the localpart and the display name from the UserInfo. - If a user already exists with the mxid we've mapped, raise an exception. + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. 
Args: userinfo: an object representing the user @@ -905,21 +907,31 @@ class OidcHandler: localpart = map_username_to_mxid_localpart(attributes["localpart"]) - user_id = UserID(localpart, self._hostname) - if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + user_id = UserID(localpart, self._hostname).to_string() + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + if self._allow_existing_users: + if len(users) == 1: + registered_user_id = next(iter(users)) + elif user_id in users: + registered_user_id = user_id + else: + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, list(users.keys()) + ) + ) + else: + # This mxid is taken + raise MappingException("mxid '{}' is already taken".format(user_id)) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 33825e8949..48ce7ecd16 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -393,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -401,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 89ec5fcb31..5910772aa8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) self.assertEqual(mxid, "@test_user_2:test") + + # Test if the mxid is already taken + store = self.hs.get_datastore() + user3 = UserID.from_string("@test_user_3:test") + self.get_success( + store.register_user(user_id=user3.to_string(), password_hash=None) + ) + userinfo = {"sub": "test3", "username": "test_user_3"} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + + @override_config({"oidc_config": {"allow_existing_users": True}}) + def test_map_userinfo_to_existing_user(self): + """Existing users can log in with OpenID Connect when allow_existing_users is True.""" + store = self.hs.get_datastore() + user4 = UserID.from_string("@test_user_4:test") + self.get_success( + 
store.register_user(user_id=user4.to_string(), password_hash=None) + ) + userinfo = { + "sub": "test4", + "username": "test_user_4", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_4:test") -- cgit 1.5.1 From fec6f9ac178867a8e7c5410e0d25898f29bab35c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 25 Sep 2020 12:29:54 +0100 Subject: Fix occasional "Re-starting finished log context" from keyring (#8398) * Fix test_verify_json_objects_for_server_awaits_previous_requests It turns out that this wasn't really testing what it thought it was testing (in particular, `check_context` was turning failures into success, which was making the tests pass even though it wasn't clear they should have been. It was also somewhat overcomplex - we can test what it was trying to test without mocking out perspectives servers. * Fix warnings about finished logcontexts in the keyring We need to make sure that we finish the key fetching magic before we run the verifying code, to ensure that we don't mess up our logcontexts. --- changelog.d/8398.bugfix | 1 + synapse/crypto/keyring.py | 70 +++++++++++++++---------- tests/crypto/test_keyring.py | 120 ++++++++++++++++++++----------------------- 3 files changed, 101 insertions(+), 90 deletions(-) create mode 100644 changelog.d/8398.bugfix (limited to 'tests') diff --git a/changelog.d/8398.bugfix b/changelog.d/8398.bugfix new file mode 100644 index 0000000000..e432aeebf1 --- /dev/null +++ b/changelog.d/8398.bugfix @@ -0,0 +1 @@ +Fix "Re-starting finished log context" warning when receiving an event we already had over federation. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 42e4087a92..c04ad77cf9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -42,7 +42,6 @@ from synapse.api.errors import ( ) from synapse.logging.context import ( PreserveLoggingContext, - current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -233,8 +232,6 @@ class Keyring: """ try: - ctx = current_context() - # map from server name to a set of outstanding request ids server_to_request_ids = {} @@ -265,12 +262,8 @@ class Keyring: # if there are no more requests for this server, we can drop the lock. if not server_requests: - with PreserveLoggingContext(ctx): - logger.debug("Releasing key lookup lock on %s", server_name) - - # ... but not immediately, as that can cause stack explosions if - # we get a long queue of lookups. 
- self.clock.call_later(0, drop_server_lock, server_name) + logger.debug("Releasing key lookup lock on %s", server_name) + drop_server_lock(server_name) return res @@ -335,20 +328,32 @@ class Keyring: ) # look for any requests which weren't satisfied - with PreserveLoggingContext(): - for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) + while remaining_requests: + verify_request = remaining_requests.pop() + rq_str = ( + "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, ) + ) + + # If we run the errback immediately, it may cancel our + # loggingcontext while we are still in it, so instead we + # schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we + # has a massive queue of lookups waiting for this server). + self.clock.call_later( + 0, + verify_request.key_ready.errback, + SynapseError( + 401, + "Failed to find any key to satisfy %s" % (rq_str,), + Codes.UNAUTHORIZED, + ), + ) except Exception as err: # we don't really expect to get here, because any errors should already # have been caught and logged. But if we do, let's log the error and make @@ -410,10 +415,23 @@ class Keyring: # key was not valid at this point continue - with PreserveLoggingContext(): - verify_request.key_ready.callback( - (server_name, key_id, fetch_key_result.verify_key) - ) + # we have a valid key for this request. If we run the callback + # immediately, it may cancel our loggingcontext while we are still in + # it, so instead we schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we had + # a massive queue of lookups waiting for this server). 
+ logger.debug( + "Found key %s:%s for %s", + server_name, + key_id, + verify_request.request_name, + ) + self.clock.call_later( + 0, + verify_request.key_ready.callback, + (server_name, key_id, fetch_key_result.verify_key), + ) completed.append(verify_request) break diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2e6e7abf1f..5cf408f21f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -23,6 +23,7 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer +from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError from synapse.crypto import keyring @@ -33,7 +34,6 @@ from synapse.crypto.keyring import ( ) from synapse.logging.context import ( LoggingContext, - PreserveLoggingContext, current_context, make_deferred_yieldable, ) @@ -68,54 +68,40 @@ class MockPerspectiveServer: class KeyringTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - self.mock_perspective_server = MockPerspectiveServer() - self.http_client = Mock() - - config = self.default_config() - config["trusted_key_servers"] = [ - { - "server_name": self.mock_perspective_server.server_name, - "verify_keys": self.mock_perspective_server.get_verify_keys(), - } - ] - - return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config - ) - - def check_context(self, _, expected): + def check_context(self, val, expected): self.assertEquals(getattr(current_context(), "request", None), expected) + return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - key1 = signedjson.key.generate_signing_key(1) + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock() + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) - kr = keyring.Keyring(self.hs) + # a signed object that we are going to try to validate + key1 = signedjson.key.generate_signing_key(1) json1 = {} signedjson.sign.sign_json(json1, "server10", key1) - persp_resp = { - "server_keys": [ - self.mock_perspective_server.get_signed_key( - "server10", signedjson.key.get_verify_key(key1) - ) - ] - } - persp_deferred = defer.Deferred() + # start off a first set of lookups. We make the mock fetcher block until this + # deferred completes. 
+ first_lookup_deferred = Deferred() + + async def first_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_11") + self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) - async def get_perspectives(**kwargs): - self.assertEquals(current_context().request, "11") - with PreserveLoggingContext(): - await persp_deferred - return persp_resp + await make_deferred_yieldable(first_lookup_deferred) + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } - self.http_client.post_json.side_effect = get_perspectives + mock_fetcher.get_keys.side_effect = first_lookup_fetch - # start off a first set of lookups - @defer.inlineCallbacks - def first_lookup(): - with LoggingContext("11") as context_11: - context_11.request = "11" + async def first_lookup(): + with LoggingContext("context_11") as context_11: + context_11.request = "context_11" res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] @@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # the unsigned json should be rejected pretty quickly self.assertTrue(res_deferreds[1].called) try: - yield res_deferreds[1] + await res_deferreds[1] self.assertFalse("unsigned json didn't cause a failure") except SynapseError: pass @@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(res_deferreds[0].called) res_deferreds[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds[0]) + await make_deferred_yieldable(res_deferreds[0]) - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) + d0 = ensureDeferred(first_lookup()) - d0 = first_lookup() - - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - self.pump() - self.http_client.post_json.assert_called_once() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - @defer.inlineCallbacks - def second_lookup(): - with LoggingContext("12") as context_12: - context_12.request = "12" - self.http_client.post_json.reset_mock() - self.http_client.post_json.return_value = defer.Deferred() + + async def second_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_12") + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } + + mock_fetcher.get_keys.reset_mock() + mock_fetcher.get_keys.side_effect = second_lookup_fetch + second_lookup_state = [0] + + async def second_lookup(): + with LoggingContext("context_12") as context_12: + context_12.request = "context_12" res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 1 + await make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 2 - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) - - d2 = second_lookup() + d2 = ensureDeferred(second_lookup()) self.pump() - self.http_client.post_json.assert_not_called() + # the second request should be pending, but the fetcher should not yet have been + # called + self.assertEqual(second_lookup_state[0], 1) + mock_fetcher.get_keys.assert_not_called() # complete the first request - 
persp_deferred.callback(persp_resp) + first_lookup_deferred.callback(None) + + # and now both verifications should succeed. self.get_success(d0) self.get_success(d2) -- cgit 1.5.1 From 5e3ca12b158b4abefe2e3a54259ab5255dca93d8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 28 Sep 2020 17:58:33 +0100 Subject: Create a mechanism for marking tests "logcontext clean" (#8399) --- changelog.d/8399.misc | 1 + synapse/logging/context.py | 43 +++++++++++++++++++++++-------------------- tests/crypto/test_keyring.py | 3 +++ tests/unittest.py | 15 ++++++++++++++- 4 files changed, 41 insertions(+), 21 deletions(-) create mode 100644 changelog.d/8399.misc (limited to 'tests') diff --git a/changelog.d/8399.misc b/changelog.d/8399.misc new file mode 100644 index 0000000000..ce6e8123cf --- /dev/null +++ b/changelog.d/8399.misc @@ -0,0 +1 @@ +Create a mechanism for marking tests "logcontext clean". diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2e282d9d67..ca0c774cc5 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -65,6 +65,11 @@ except Exception: return None +# a hook which can be set during testing to assert that we aren't abusing logcontexts. +def logcontext_error(msg: str): + logger.warning(msg) + + # get an id for the current thread. # # threading.get_ident doesn't actually return an OS-level tid, and annoyingly, @@ -330,10 +335,9 @@ class LoggingContext: """Enters this logging context into thread local storage""" old_context = set_current_context(self) if self.previous_context != old_context: - logger.warning( - "Expected previous context %r, found %r", - self.previous_context, - old_context, + logcontext_error( + "Expected previous context %r, found %r" + % (self.previous_context, old_context,) ) return self @@ -346,10 +350,10 @@ class LoggingContext: current = set_current_context(self.previous_context) if current is not self: if current is SENTINEL_CONTEXT: - logger.warning("Expected logging context %s was lost", self) + logcontext_error("Expected logging context %s was lost" % (self,)) else: - logger.warning( - "Expected logging context %s but found %s", self, current + logcontext_error( + "Expected logging context %s but found %s" % (self, current) ) # the fact that we are here suggests that the caller thinks that everything @@ -387,16 +391,16 @@ class LoggingContext: support getrusuage. 
""" if get_thread_id() != self.main_thread: - logger.warning("Started logcontext %s on different thread", self) + logcontext_error("Started logcontext %s on different thread" % (self,)) return if self.finished: - logger.warning("Re-starting finished log context %s", self) + logcontext_error("Re-starting finished log context %s" % (self,)) # If we haven't already started record the thread resource usage so # far if self.usage_start: - logger.warning("Re-starting already-active log context %s", self) + logcontext_error("Re-starting already-active log context %s" % (self,)) else: self.usage_start = rusage @@ -414,7 +418,7 @@ class LoggingContext: try: if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) + logcontext_error("Stopped logcontext %s on different thread" % (self,)) return if not rusage: @@ -422,9 +426,9 @@ class LoggingContext: # Record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without recording a start rusage", - self, + logcontext_error( + "Called stop on logcontext %s without recording a start rusage" + % (self,) ) return @@ -584,14 +588,13 @@ class PreserveLoggingContext: if context != self._new_context: if not context: - logger.warning( - "Expected logging context %s was lost", self._new_context + logcontext_error( + "Expected logging context %s was lost" % (self._new_context,) ) else: - logger.warning( - "Expected logging context %s but found %s", - self._new_context, - context, + logcontext_error( + "Expected logging context %s but found %s" + % (self._new_context, context,) ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 5cf408f21f..8ff1460c0d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -41,6 +41,7 @@ from synapse.storage.keys import FetchKeyResult from tests import unittest from tests.test_utils import make_awaitable +from tests.unittest import logcontext_clean class MockPerspectiveServer: @@ -67,6 +68,7 @@ class MockPerspectiveServer: signedjson.sign.sign_json(res, self.server_name, self.key) +@logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): self.assertEquals(getattr(current_context(), "request", None), expected) @@ -309,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): mock_fetcher2.get_keys.assert_called_once() +@logcontext_clean class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() diff --git a/tests/unittest.py b/tests/unittest.py index dabf69cff4..bbe50c3851 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -23,7 +23,7 @@ import logging import time from typing import Optional, Tuple, Type, TypeVar, Union -from mock import Mock +from mock import Mock, patch from canonicaljson import json @@ -169,6 +169,19 @@ def INFO(target): return target +def logcontext_clean(target): + """A decorator which marks the TestCase or method as 'logcontext_clean' + + ... ie, any logcontext errors should cause a test failure + """ + + def logcontext_error(msg): + raise AssertionError("logcontext error: %s" % (msg)) + + patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error) + return patcher(target) + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. 
-- cgit 1.5.1 From bd380d942fdf91cf1214d6859f2bc97d12a92ab4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 28 Sep 2020 18:00:30 +0100 Subject: Add checks for postgres sequence consistency (#8402) --- changelog.d/8402.misc | 1 + docs/postgres.md | 11 ++++ synapse/storage/databases/main/registration.py | 3 + synapse/storage/databases/state/store.py | 3 + synapse/storage/util/id_generators.py | 5 ++ synapse/storage/util/sequence.py | 90 +++++++++++++++++++++++++- tests/storage/test_id_generators.py | 22 ++++++- tests/unittest.py | 31 ++++++++- 8 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8402.misc (limited to 'tests') diff --git a/changelog.d/8402.misc b/changelog.d/8402.misc new file mode 100644 index 0000000000..ad1804d207 --- /dev/null +++ b/changelog.d/8402.misc @@ -0,0 +1 @@ +Add checks on startup that PostgreSQL sequences are consistent with their associated tables. diff --git a/docs/postgres.md b/docs/postgres.md index e71a1975d8..c30cc1fd8c 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -106,6 +106,17 @@ Note that the above may fail with an error about duplicate rows if corruption has already occurred, and such duplicate rows will need to be manually removed. +## Fixing inconsistent sequences error + +Synapse uses Postgres sequences to generate IDs for various tables. A sequence +and associated table can get out of sync if, for example, Synapse has been +downgraded and then upgraded again. + +To fix the issue shut down Synapse (including any and all workers) and run the +SQL command included in the error message. Once done Synapse should start +successfully. + + ## Tuning Postgres The default settings should be fine for most deployments. For larger diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 48ce7ecd16..a83df7759d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -41,6 +41,9 @@ class RegistrationWorkerStore(SQLBaseStore): self.config = hs.config self.clock = hs.get_clock() + # Note: we don't check this sequence for consistency as we'd have to + # call `find_max_generated_user_id_localpart` each time, which is + # expensive if there are many entries. self._user_id_seq = build_sequence_generator( database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index bec3780a32..989f0cbc9d 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_seq_gen = build_sequence_generator( self.database_engine, get_max_state_group_txn, "state_group_id_seq" ) + self._state_group_seq_gen.check_consistency( + db_conn, table="state_groups", id_column="id" + ) @cached(max_entries=10000, iterable=True) async def get_state_group_delta(self, state_group): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4269eaf918..4fd7573e26 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -258,6 +258,11 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # We check that the table and sequence haven't diverged. 
+ self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) + # This goes and fills out the above state from the database. self._load_current_ids(db_conn, table, instance_column, id_column) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index ffc1894748..2dd95e2709 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -13,11 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import threading from typing import Callable, List, Optional -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.engines import ( + BaseDatabaseEngine, + IncorrectDatabaseSetup, + PostgresEngine, +) +from synapse.storage.types import Connection, Cursor + +logger = logging.getLogger(__name__) + + +_INCONSISTENT_SEQUENCE_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated +table '%(table)s'. This can happen if Synapse has been downgraded and +then upgraded again, or due to a bad migration. + +To fix this error, shut down Synapse (including any and all workers) +and run the following SQL: + + SELECT setval('%(seq)s', ( + %(max_id_sql)s + )); + +See docs/postgres.md for more information. +""" class SequenceGenerator(metaclass=abc.ABCMeta): @@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + """Should be called during start up to test that the current value of + the sequence is greater than or equal to the maximum ID in the table. + + This is to handle various cases where the sequence value can get out + of sync with the table, e.g. if Synapse gets rolled back to a previous + version and the rolled forwards again. + """ + ... + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator): ) return [i for (i,) in txn] + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + txn = db_conn.cursor() + + # First we get the current max ID from the table. + table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { + "id": id_column, + "table": table, + "agg": "MAX" if positive else "-MIN", + } + + txn.execute(table_sql) + row = txn.fetchone() + if not row: + # Table is empty, so nothing to do. + txn.close() + return + + # Now we fetch the current value from the sequence and compare with the + # above. + max_stream_id = row[0] + txn.execute( + "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} + ) + last_value, is_called = txn.fetchone() + txn.close() + + # If `is_called` is False then `last_value` is actually the value that + # will be generated next, so we decrement to get the true "last value". 
+ if not is_called: + last_value -= 1 + + if max_stream_id > last_value: + logger.warning( + "Postgres sequence %s is behind table %s: %d < %d", + last_value, + max_stream_id, + ) + raise IncorrectDatabaseSetup( + _INCONSISTENT_SEQUENCE_ERROR + % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + # There is nothing to do for in memory sequences + pass + def build_sequence_generator( database_engine: BaseDatabaseEngine, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index d4ff55fbff..4558bee7be 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from synapse.storage.database import DatabasePool +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.unittest import HomeserverTestCase @@ -59,7 +58,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): writers=writers, ) - return self.get_success(self.db_pool.runWithConnection(_create)) + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): """Insert N rows as the given instance, inserting with stream IDs pulled @@ -411,6 +410,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async()) self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + def test_sequence_consistency(self): + """Test that we error out if the table and sequence diverges. + """ + + # Prefill with some rows + self._insert_row_with_id("master", 3) + + # Now we add a row *without* updating the stream ID + def _insert(txn): + txn.execute("INSERT INTO foobar VALUES (26, 'master')") + + self.get_success(self.db_pool.runInteraction("_insert", _insert)) + + # Creating the ID gen should error + with self.assertRaises(IncorrectDatabaseSetup): + self._create_id_generator("first") + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. diff --git a/tests/unittest.py b/tests/unittest.py index bbe50c3851..e654c0442d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import gc import hashlib import hmac @@ -28,6 +27,7 @@ from mock import Mock, patch from canonicaljson import json from twisted.internet.defer import Deferred, ensureDeferred, succeed +from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -476,6 +476,35 @@ class HomeserverTestCase(TestCase): self.pump() return self.failureResultOf(d, exc) + def get_success_or_raise(self, d, by=0.0): + """Drive deferred to completion and return result or raise exception + on failure. 
+ """ + + if inspect.isawaitable(d): + deferred = ensureDeferred(d) + if not isinstance(deferred, Deferred): + return d + + results = [] # type: list + deferred.addBoth(results.append) + + self.pump(by=by) + + if not results: + self.fail( + "Success result expected on {!r}, found no result instead".format( + deferred + ) + ) + + result = results[0] + + if isinstance(result, Failure): + result.raiseException() + + return result + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. -- cgit 1.5.1 From 1c262431f9bf768d106bf79a568479fa5a0784a1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Sep 2020 10:29:21 +0100 Subject: Fix handling of connection timeouts in outgoing http requests (#8400) * Remove `on_timeout_cancel` from `timeout_deferred` The `on_timeout_cancel` param to `timeout_deferred` wasn't always called on a timeout (in particular if the canceller raised an exception), so it was unreliable. It was also only used in one place, and to be honest it's easier to do what it does a different way. * Fix handling of connection timeouts in outgoing http requests Turns out that if we get a timeout during connection, then a different exception is raised, which wasn't always handled correctly. To fix it, catch the exception in SimpleHttpClient and turn it into a RequestTimedOutError (which is already a documented exception). Also add a description to RequestTimedOutError so that we can see which stage it failed at. * Fix incorrect handling of timeouts reading federation responses This was trapping the wrong sort of TimeoutError, so was never being hit. The effect was relatively minor, but we should fix this so that it does the expected thing. * Fix inconsistent handling of `timeout` param between methods `get_json`, `put_json` and `delete_json` were applying a different timeout to the response body to `post_json`; bring them in line and test. Co-authored-by: Patrick Cloke Co-authored-by: Erik Johnston --- changelog.d/8400.bugfix | 1 + synapse/handlers/identity.py | 25 +++-- synapse/http/__init__.py | 17 +--- synapse/http/client.py | 54 ++++++---- synapse/http/matrixfederationclient.py | 55 +++++++--- synapse/http/proxyagent.py | 16 ++- synapse/util/async_helpers.py | 47 ++++----- tests/http/test_fedclient.py | 14 +-- tests/http/test_simple_client.py | 180 +++++++++++++++++++++++++++++++++ 9 files changed, 311 insertions(+), 98 deletions(-) create mode 100644 changelog.d/8400.bugfix create mode 100644 tests/http/test_simple_client.py (limited to 'tests') diff --git a/changelog.d/8400.bugfix b/changelog.d/8400.bugfix new file mode 100644 index 0000000000..835658ba5e --- /dev/null +++ b/changelog.d/8400.bugfix @@ -0,0 +1 @@ +Fix incorrect handling of timeouts on outgoing HTTP requests. 
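The caller-side pattern this standardises on is summarised in the minimal sketch below (it mirrors the identity-handler hunks that follow rather than reproducing them; `client` is assumed to be a SimpleHttpClient):

from synapse.api.errors import SynapseError
from synapse.http import RequestTimedOutError

async def get_json_or_504(client, url: str) -> dict:
    # Connection timeouts and response timeouts now both surface as
    # RequestTimedOutError, so callers only handle one exception type.
    try:
        return await client.get_json(url)
    except RequestTimedOutError:
        raise SynapseError(500, "Timed out contacting remote server")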
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ab15570f7a..bc3e9607ca 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,8 +21,6 @@ import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple -from twisted.internet.error import TimeoutError - from synapse.api.errors import ( CodeMessageException, Codes, @@ -30,6 +28,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, Requester from synapse.util import json_decoder @@ -93,7 +92,7 @@ class IdentityHandler(BaseHandler): try: data = await self.http_client.get_json(url, query_params) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.info( @@ -173,7 +172,7 @@ class IdentityHandler(BaseHandler): if e.code != 404 or not use_v2: logger.error("3PID bind failed with Matrix error: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: data = json_decoder.decode(e.msg) # XXX WAT? @@ -273,7 +272,7 @@ class IdentityHandler(BaseHandler): else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") await self.store.remove_user_bound_threepid( @@ -419,7 +418,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") async def requestMsisdnToken( @@ -471,7 +470,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") assert self.hs.config.public_baseurl @@ -553,7 +552,7 @@ class IdentityHandler(BaseHandler): id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", body, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) @@ -627,7 +626,7 @@ class IdentityHandler(BaseHandler): # require or validate it. 
See the following for context: # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950 return data["mxid"] - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except IOError as e: logger.warning("Error from v1 identity server lookup: %s" % (e,)) @@ -655,7 +654,7 @@ class IdentityHandler(BaseHandler): "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), {"access_token": id_access_token}, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") if not isinstance(hash_details, dict): @@ -727,7 +726,7 @@ class IdentityHandler(BaseHandler): }, headers=headers, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except Exception as e: logger.warning("Error when performing a v2 3pid lookup: %s", e) @@ -823,7 +822,7 @@ class IdentityHandler(BaseHandler): invite_config, {"Authorization": create_id_access_token_header(id_access_token)}, ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: if e.code != 404: @@ -841,7 +840,7 @@ class IdentityHandler(BaseHandler): data = await self.blacklisting_http_client.post_json_get_json( url, invite_config ) - except TimeoutError: + except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 8eb3638591..59b01b812c 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -16,8 +16,6 @@ import re from twisted.internet import task -from twisted.internet.defer import CancelledError -from twisted.python import failure from twisted.web.client import FileBodyProducer from synapse.api.errors import SynapseError @@ -26,19 +24,8 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self): - super().__init__(504, "Timed out") - - -def cancelled_to_request_timed_out_error(value, timeout): - """Turns CancelledErrors into RequestTimedOutErrors. - - For use with async.add_timeout_to_deferred - """ - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise RequestTimedOutError() - return value + def __init__(self, msg): + super().__init__(504, msg) ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") diff --git a/synapse/http/client.py b/synapse/http/client.py index 4694adc400..8324632cb6 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import logging import urllib from io import BytesIO @@ -38,7 +37,7 @@ from zope.interface import implementer, provider from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE -from twisted.internet import defer, protocol, ssl +from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IResolutionReceiver, @@ -46,17 +45,18 @@ from twisted.internet.interfaces import ( from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone -from twisted.web.client import Agent, HTTPConnectionPool, readBody +from twisted.web.client import ( + Agent, + HTTPConnectionPool, + ResponseNeverReceived, + readBody, +) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError -from synapse.http import ( - QuieterFileBodyProducer, - cancelled_to_request_timed_out_error, - redact_uri, -) +from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags @@ -332,8 +332,6 @@ class SimpleHttpClient: RequestTimedOutError if the request times out before the headers are read """ - # A small wrapper around self.agent.request() so we can easily attach - # counters to it outgoing_requests_counter.labels(method).inc() # log request but strip `access_token` (AS requests for example include this) @@ -362,15 +360,17 @@ class SimpleHttpClient: data=body_producer, headers=headers, **self._extra_treq_args - ) + ) # type: defer.Deferred + # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), - cancelled_to_request_timed_out_error, + request_deferred, 60, self.hs.get_reactor(), ) + + # turn timeouts into RequestTimedOutErrors + request_deferred.addErrback(_timeout_to_request_timed_out_error) + response = await make_deferred_yieldable(request_deferred) incoming_responses_counter.labels(method, response.code).inc() @@ -410,7 +410,7 @@ class SimpleHttpClient: parsed json Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -461,7 +461,7 @@ class SimpleHttpClient: parsed json Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -506,7 +506,7 @@ class SimpleHttpClient: Returns: Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -538,7 +538,7 @@ class SimpleHttpClient: Returns: Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. 
Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -586,7 +586,7 @@ class SimpleHttpClient: Succeeds when we get a 2xx HTTP response, with the HTTP body as bytes. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -631,7 +631,7 @@ class SimpleHttpClient: headers, absolute URI of the response and HTTP response code. Raises: - RequestTimedOutException: if there is a timeout before the response headers + RequestTimedOutError: if there is a timeout before the response headers are received. Note there is currently no timeout on reading the response body. @@ -684,6 +684,18 @@ class SimpleHttpClient: ) +def _timeout_to_request_timed_out_error(f: Failure): + if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError): + # The TCP connection has its own timeout (set by the 'connectTimeout' param + # on the Agent), which raises twisted_error.TimeoutError exception. + raise RequestTimedOutError("Timeout connecting to remote server") + elif f.check(defer.TimeoutError, ResponseNeverReceived): + # this one means that we hit our overall timeout on the request + raise RequestTimedOutError("Timeout waiting for response from remote server") + + return f + + # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b02c74ab2d..c23a4d7c0c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -171,7 +171,7 @@ async def _handle_json_response( d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = await make_deferred_yieldable(d) - except TimeoutError as e: + except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", request.txn_id, @@ -655,10 +655,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. backoff_on_404 (bool): True if we should count a 404 response as @@ -704,8 +708,13 @@ class MatrixFederationHttpClient: timeout=timeout, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body @@ -734,10 +743,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. 
- timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -801,10 +814,14 @@ class MatrixFederationHttpClient: args (dict|None): A dictionary used to create query strings, defaults to None. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -840,8 +857,13 @@ class MatrixFederationHttpClient: timeout=timeout, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body @@ -865,10 +887,14 @@ class MatrixFederationHttpClient: long_retries (bool): whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response headers - (including connecting to the server), *for each attempt*. + timeout (int|None): number of milliseconds to wait for the response. self._default_timeout (60s) by default. + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. @@ -900,8 +926,13 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, ) + if timeout is not None: + _sec_timeout = timeout / 1000 + else: + _sec_timeout = self.default_timeout + body = await _handle_json_response( - self.reactor, self.default_timeout, request, response, start_ms + self.reactor, _sec_timeout, request, response, start_ms ) return body diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 332da02a8d..e32d3f43e0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -44,8 +44,11 @@ class ProxyAgent(_AgentBase): `BrowserLikePolicyForHTTPS`, so unless you have special requirements you can leave this as-is. - connectTimeout (float): The amount of time that this Agent will wait - for the peer to accept a connection. + connectTimeout (Optional[float]): The amount of time that this Agent will wait + for the peer to accept a connection, in seconds. If 'None', + HostnameEndpoint's default (30s) will be used. + + This is used for connections to both proxies and destination servers. 
bindAddress (bytes): The local address for client sockets to bind to. @@ -108,6 +111,15 @@ class ProxyAgent(_AgentBase): Returns: Deferred[IResponse]: completes when the header of the response has been received (regardless of the response status code). + + Can fail with: + SchemeNotSupported: if the uri is not http or https + + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. + + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 67ce9a5f39..382f0cf3f0 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -449,18 +449,8 @@ class ReadWriteLock: R = TypeVar("R") -def _cancelled_to_timed_out_error(value: R, timeout: float) -> R: - if isinstance(value, failure.Failure): - value.trap(CancelledError) - raise defer.TimeoutError(timeout, "Deferred") - return value - - def timeout_deferred( - deferred: defer.Deferred, - timeout: float, - reactor: IReactorTime, - on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None, + deferred: defer.Deferred, timeout: float, reactor: IReactorTime, ) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -469,27 +459,21 @@ def timeout_deferred( (See https://twistedmatrix.com/trac/ticket/9534) - NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred + NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred. + + NOTE: the TimeoutError raised by the resultant deferred is + twisted.internet.defer.TimeoutError, which is *different* to the built-in + TimeoutError, as well as various other TimeoutErrors you might have imported. Args: deferred: The Deferred to potentially timeout. timeout: Timeout in seconds reactor: The twisted reactor to use - on_timeout_cancel: A callable which is called immediately - after the deferred times out, and not if this deferred is - otherwise cancelled before the timeout. - It takes an arbitrary value, which is the value of the deferred at - that exact point in time (probably a CancelledError Failure), and - the timeout. - - The default callable (if none is provided) will translate a - CancelledError Failure into a defer.TimeoutError. Returns: - A new Deferred. + A new Deferred, which will errback with defer.TimeoutError on timeout. """ - new_d = defer.Deferred() timed_out = [False] @@ -502,18 +486,23 @@ def timeout_deferred( except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") + # the cancel() call should have set off a chain of errbacks which + # will have errbacked new_d, but in case it hasn't, errback it now. + if not new_d.called: - new_d.errback(defer.TimeoutError(timeout, "Deferred")) + new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,))) delayed_call = reactor.callLater(timeout, time_it_out) - def convert_cancelled(value): - if timed_out[0]: - to_call = on_timeout_cancel or _cancelled_to_timed_out_error - return to_call(value, timeout) + def convert_cancelled(value: failure.Failure): + # if the orgininal deferred was cancelled, and our timeout has fired, then + # the reason it was cancelled was due to our timeout. Turn the CancelledError + # into a TimeoutError. 
+ if timed_out[0] and value.check(CancelledError): + raise defer.TimeoutError("Timed out after %gs" % (timeout,)) return value - deferred.addBoth(convert_cancelled) + deferred.addErrback(convert_cancelled) def cancel_timeout(result): # stop the pending call to cancel the deferred if it's been fired diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 5604af3795..212484a7fe 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -318,14 +318,14 @@ class FederationClientTests(HomeserverTestCase): r = self.successResultOf(d) self.assertEqual(r.code, 200) - def test_client_headers_no_body(self): + @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"]) + def test_timeout_reading_body(self, method_name: str): """ If the HTTP request is connected, but gets no response before being - timed out, it'll give a ResponseNeverReceived. + timed out, it'll give a RequestSendFailed with can_retry. """ - d = defer.ensureDeferred( - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) - ) + method = getattr(self.cl, method_name) + d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000)) self.pump() @@ -349,7 +349,9 @@ class FederationClientTests(HomeserverTestCase): self.reactor.advance(10.5) f = self.failureResultOf(d) - self.assertIsInstance(f.value, TimeoutError) + self.assertIsInstance(f.value, RequestSendFailed) + self.assertTrue(f.value.can_retry) + self.assertIsInstance(f.value.inner_exception, defer.TimeoutError) def test_client_requires_trailing_slashes(self): """ diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py new file mode 100644 index 0000000000..a1cf0862d4 --- /dev/null +++ b/tests/http/test_simple_client.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from netaddr import IPSet + +from twisted.internet import defer +from twisted.internet.error import DNSLookupError + +from synapse.http import RequestTimedOutError +from synapse.http.client import SimpleHttpClient +from synapse.server import HomeServer + +from tests.unittest import HomeserverTestCase + + +class SimpleHttpClientTests(HomeserverTestCase): + def prepare(self, reactor, clock, hs: "HomeServer"): + # Add a DNS entry for a test server + self.reactor.lookups["testserv"] = "1.2.3.4" + + self.cl = hs.get_simple_http_client() + + def test_dns_error(self): + """ + If the DNS lookup returns an error, it will bubble up. 
+ """ + d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar")) + self.pump() + + f = self.failureResultOf(d) + self.assertIsInstance(f.value, DNSLookupError) + + def test_client_connection_refused(self): + d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + e = Exception("go away") + factory.clientConnectionFailed(None, e) + self.pump(0.5) + + f = self.failureResultOf(d) + + self.assertIs(f.value, e) + + def test_client_never_connect(self): + """ + If the HTTP request is not connected and is timed out, it'll give a + ConnectingCancelledError or TimeoutError. + """ + d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], "1.2.3.4") + self.assertEqual(clients[0][1], 8008) + + # Deferred is still without a result + self.assertNoResult(d) + + # Push by enough to time it out + self.reactor.advance(120) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, RequestTimedOutError) + + def test_client_connect_no_response(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], "1.2.3.4") + self.assertEqual(clients[0][1], 8008) + + conn = Mock() + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred is still without a result + self.assertNoResult(d) + + # Push by enough to time it out + self.reactor.advance(120) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, RequestTimedOutError) + + def test_client_ip_range_blacklist(self): + """Ensure that Synapse does not try to connect to blacklisted IPs""" + + # Add some DNS entries we'll blacklist + self.reactor.lookups["internal"] = "127.0.0.1" + self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337" + ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"]) + + cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist) + + # Try making a GET request to a blacklisted IPv4 address + # ------------------------------------------------------ + # Make the request + d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar")) + self.pump(1) + + # Check that it was unable to resolve the address + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 0) + + self.failureResultOf(d, DNSLookupError) + + # Try making a POST request to a blacklisted IPv6 address + # ------------------------------------------------------- + # Make the request + d = defer.ensureDeferred( + cl.post_json_get_json("http://internalv6:8008/foo/bar", {}) + ) + + # Move the reactor forwards + self.pump(1) + + # Check that it was unable to resolve the address + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 0) + + # Check that it was due to a blacklisted DNS lookup + 
self.failureResultOf(d, DNSLookupError) + + # Try making a GET request to a non-blacklisted IPv4 address + # ---------------------------------------------------------- + # Make the request + d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar")) + + # Nothing has happened yet + self.assertNoResult(d) + + # Move the reactor forwards + self.pump(1) + + # Check that it was able to resolve the address + clients = self.reactor.tcpClients + self.assertNotEqual(len(clients), 0) + + # Connection will still fail as this IP address does not resolve to anything + self.failureResultOf(d, RequestTimedOutError) -- cgit 1.5.1 From 1c6b8752b891c1a25524d8dfaa8efb7176c0dbec Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 29 Sep 2020 12:36:44 +0100 Subject: Only assert valid next_link params when provided (#8417) Broken in https://github.com/matrix-org/synapse/pull/8275 and has yet to be put in a release. Fixes https://github.com/matrix-org/synapse/issues/8418. `next_link` is an optional parameter. However, we were checking whether the `next_link` param was valid, even if it wasn't provided. In that case, `next_link` was `None`, which would clearly not be a valid URL. This would prevent password reset and other operations if `next_link` was not provided, and the `next_link_domain_whitelist` config option was set. --- changelog.d/8417.feature | 1 + synapse/rest/client/v2_alpha/account.py | 15 +++++++++------ tests/rest/client/v2_alpha/test_account.py | 6 ++++++ 3 files changed, 16 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8417.feature (limited to 'tests') diff --git a/changelog.d/8417.feature b/changelog.d/8417.feature new file mode 100644 index 0000000000..17549c3df3 --- /dev/null +++ b/changelog.d/8417.feature @@ -0,0 +1 @@ +Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number. \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c3ce0f6259..9245214f36 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -103,8 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) # The email will be sent to the stored address. 
# This avoids a potential account hijack by requesting a password reset to @@ -379,8 +380,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) existing_user_id = await self.store.get_user_id_by_threepid("email", email) @@ -453,8 +455,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 93f899d861..ae2cd67f35 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -732,6 +732,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) def test_next_link_domain_whitelist(self): """Tests next_link parameters must fit the whitelist if provided""" + + # Ensure not providing a next_link parameter still works + self._request_token( + "something@example.com", "some_secret", next_link=None, expect_code=200, + ) + self._request_token( "something@example.com", "some_secret", -- cgit 1.5.1 From 8676d8ab2e5667d7c12774effc64b3ab99344a8d Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Tue, 29 Sep 2020 13:11:02 +0100 Subject: Filter out appservices from mau count (#8404) This is an attempt to fix #8403. --- changelog.d/8404.misc | 1 + synapse/storage/databases/main/monthly_active_users.py | 9 ++++++++- tests/storage/test_monthly_active_users.py | 17 ++++++++++++++++- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8404.misc (limited to 'tests') diff --git a/changelog.d/8404.misc b/changelog.d/8404.misc new file mode 100644 index 0000000000..7aadded6c1 --- /dev/null +++ b/changelog.d/8404.misc @@ -0,0 +1 @@ +Do not include appservice users when calculating the total MAU for a server. 
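The query added in the following hunk can be exercised in isolation against a toy SQLite schema (simplified from the real one) to see why the LEFT JOIN excludes appservice users while still counting everyone else:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE users (name TEXT, appservice_id TEXT);
        CREATE TABLE monthly_active_users (user_id TEXT, timestamp BIGINT);
        INSERT INTO users VALUES ('@alice:server', NULL), ('@bot:server', 'wibble');
        INSERT INTO monthly_active_users VALUES ('@alice:server', 0), ('@bot:server', 0);
        """
    )
    (count,) = conn.execute(
        """
        SELECT COALESCE(count(*), 0)
        FROM monthly_active_users
        LEFT JOIN users ON monthly_active_users.user_id = users.name
        WHERE (users.appservice_id IS NULL OR users.appservice_id = '')
        """
    ).fetchone()
    assert count == 1  # only @alice is counted; the appservice bot is filtered out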
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e0cedd1aac..e93aad33cd 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -41,7 +41,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + # Exclude app service users + sql = """ + SELECT COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users + ON monthly_active_users.user_id=users.name + WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); + """ txn.execute(sql) (count,) = txn.fetchone() return count diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 643072bbaf..8d97b6d4cd 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 1) + def test_appservice_user_not_counted_in_mau(self): + self.get_success( + self.store.register_user( + user_id="@appservice_user:server", appservice_id="wibble" + ) + ) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + d = self.store.upsert_monthly_active_user("@appservice_user:server") + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" @@ -383,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, 4) + self.assertEqual(count, 1) d = self.store.get_monthly_active_count_by_service() result = self.get_success(d) -- cgit 1.5.1 From b1433bf231370636b817ffa01e6cda5a567cfafe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Sep 2020 16:42:19 +0100 Subject: Don't table scan events on worker startup (#8419) * Fix table scan of events on worker startup. This happened because we assumed "new" writers had an initial stream position of 0, so the replication code tried to fetch all events written by the instance between 0 and the current position. Instead, set the initial position of new writers to the current persisted up to position, on the assumption that new writers won't have written anything before that point. * Consider old writers coming back as "new". Otherwise we'd try and fetch entries between the old stale token and the current position, even though it won't have written any rows. Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/8419.feature | 1 + synapse/storage/util/id_generators.py | 26 +++++++++++++++++++++++++- tests/storage/test_id_generators.py | 18 ++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8419.feature (limited to 'tests') diff --git a/changelog.d/8419.feature b/changelog.d/8419.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8419.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. 
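The reasoning in the commit message above can be made concrete with a toy calculation: replication catch-up fetches every row between a writer's assumed position and the current position, so the initial position chosen for a previously unseen writer decides how much work that is. The numbers below are made up purely for illustration:

    current_persisted_upto = 5_000_000  # hypothetical current stream position

    def rows_to_catch_up(assumed_writer_position: int) -> int:
        # replication fetches everything the writer is presumed to have written
        # between its assumed position and the current position
        return current_persisted_upto - assumed_writer_position

    # treating a brand-new writer as starting at 0 implies scanning the whole table
    assert rows_to_catch_up(0) == 5_000_000
    # starting it at the persisted-up-to position implies nothing to backfill
    assert rows_to_catch_up(current_persisted_upto) == 0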
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4fd7573e26..02fbb656e8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -273,6 +273,19 @@ class MultiWriterIdGenerator: # Load the current positions of all writers for the stream. if self._writers: + # We delete any stale entries in the positions table. This is + # important if we add back a writer after a long time; we want to + # consider that a "new" writer, rather than using the old stale + # entry here. + sql = """ + DELETE FROM stream_positions + WHERE + stream_name = ? + AND instance_name != ALL(?) + """ + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (self._stream_name, self._writers)) + sql = """ SELECT instance_name, stream_id FROM stream_positions WHERE stream_name = ? @@ -453,11 +466,22 @@ class MultiWriterIdGenerator: """Returns the position of the given writer. """ + # If we don't have an entry for the given instance name, we assume it's a + # new writer. + # + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. + + Note that this won't necessarily include all configured writers if some + writers haven't written anything yet. """ with self._lock: diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 4558bee7be..392b08832b 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -390,17 +390,28 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Initial config has two writers id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(id_gen.get_current_token_for_writer("second"), 5) # New config removes one of the configs. Note that if the writer is # removed from config we assume that it has been shut down and has # finished persisting, hence why the persisted upto position is 5. id_gen_2 = self._create_id_generator("second", writers=["second"]) self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5) # This config points to a single, previously unused writer. id_gen_3 = self._create_id_generator("third", writers=["third"]) self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer comes along. + self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5) + + id_gen_4 = self._create_id_generator("fourth", writers=["third"]) + self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5) + # Check that we get a sane next stream ID with this new config. 
async def _get_next_async(): @@ -410,6 +421,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async()) self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + # If we add back the old "first" then we shouldn't see the persisted up + # to position revert back to 3. + id_gen_5 = self._create_id_generator("five", writers=["first", "third"]) + self.assertEqual(id_gen_5.get_persisted_upto_position(), 6) + self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) + self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) + def test_sequence_consistency(self): """Test that we error out if the table and sequence diverges. """ -- cgit 1.5.1 From ea70f1c362dc4bd6c0f8a67e16ed0971fe095e5b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Sep 2020 21:48:33 +0100 Subject: Various clean ups to room stream tokens. (#8423) --- changelog.d/8423.misc | 1 + synapse/events/__init__.py | 6 ++-- synapse/handlers/admin.py | 2 +- synapse/handlers/device.py | 4 +-- synapse/handlers/initial_sync.py | 3 +- synapse/handlers/pagination.py | 5 ++- synapse/handlers/room.py | 4 +-- synapse/handlers/sync.py | 20 ++++++++---- synapse/notifier.py | 4 +-- synapse/replication/tcp/client.py | 6 ++-- synapse/rest/admin/__init__.py | 3 +- synapse/storage/databases/main/stream.py | 38 +++++++++++++---------- synapse/storage/persist_events.py | 5 ++- synapse/types.py | 53 ++++++++++++++++++++------------ tests/rest/client/v1/test_rooms.py | 8 ++--- tests/storage/test_purge.py | 10 +++--- 16 files changed, 96 insertions(+), 76 deletions(-) create mode 100644 changelog.d/8423.misc (limited to 'tests') diff --git a/changelog.d/8423.misc b/changelog.d/8423.misc new file mode 100644 index 0000000000..7260e3fa41 --- /dev/null +++ b/changelog.d/8423.misc @@ -0,0 +1 @@ +Various refactors to simplify stream token handling. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index bf800a3852..dc49df0812 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -118,8 +118,8 @@ class _EventInternalMetadata: # XXX: These are set by StreamWorkerStore._set_before_and_after. 
# I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before = DictProperty("before") # type: str - after = DictProperty("after") # type: str + before = DictProperty("before") # type: RoomStreamToken + after = DictProperty("after") # type: RoomStreamToken order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index dd981c597e..1ce2091b46 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -153,7 +153,7 @@ class AdminHandler(BaseHandler): if not events: break - from_key = RoomStreamToken.parse(events[-1].internal_metadata.after) + from_key = events[-1].internal_metadata.after events = await filter_events_for_client(self.storage, user_id, events) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 4149520d6c..b9d9098104 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,7 +29,6 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( - RoomStreamToken, StreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -113,8 +112,7 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_id = self.store.get_room_max_stream_ordering() - now_room_key = RoomStreamToken(None, now_room_id) + now_room_key = self.store.get_room_max_token() room_ids = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8cd7eb22a3..43f15435de 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -325,7 +325,8 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = await self.store.get_stream_token_for_event(member_event_id) + leave_position = await self.store.get_position_for_event(member_event_id) + stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a0b3bdb5e0..d6779a4b44 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import Requester, RoomStreamToken +from synapse.types import Requester from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -373,10 +373,9 @@ class PaginationHandler: # case "JOIN" would have been returned. 
assert member_event_id - leave_token_str = await self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) - leave_token = RoomStreamToken.parse(leave_token_str) assert leave_token.topological is not None if leave_token.topological < curr_topo: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 11bf146bed..836b3f381a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1134,14 +1134,14 @@ class RoomEventSource: events[:] = events[:limit] if events: - end_key = RoomStreamToken.parse(events[-1].internal_metadata.after) + end_key = events[-1].internal_metadata.after else: end_key = to_key return (events, end_key) def get_current_key(self) -> RoomStreamToken: - return RoomStreamToken(None, self.store.get_room_max_stream_ordering()) + return self.store.get_room_max_token() def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e948efef2e..bfe2583002 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -519,7 +519,7 @@ class SyncHandler: if len(recents) > timeline_limit: limited = True recents = recents[-timeline_limit:] - room_key = RoomStreamToken.parse(recents[0].internal_metadata.before) + room_key = recents[0].internal_metadata.before prev_batch_token = now_token.copy_and_replace("room_key", room_key) @@ -1595,16 +1595,24 @@ class SyncHandler: if leave_events: leave_event = leave_events[-1] - leave_stream_token = await self.store.get_stream_token_for_event( + leave_position = await self.store.get_position_for_event( leave_event.event_id ) - leave_token = since_token.copy_and_replace( - "room_key", leave_stream_token - ) - if since_token and since_token.is_after(leave_token): + # If the leave event happened before the since token then we + # bail. + if since_token and not leave_position.persisted_after( + since_token.room_key + ): continue + # We can safely convert the position of the leave event into a + # stream token as it'll only be used in the context of this + # room. (c.f. the docstring of `to_room_stream_token`). + leave_token = since_token.copy_and_replace( + "room_key", leave_position.to_room_stream_token() + ) + # If this is an out of band message, like a remote invite # rejection, we include it in the recents batch. Otherwise, we # let _load_filtered_recents handle fetching the correct diff --git a/synapse/notifier.py b/synapse/notifier.py index 441b3d15e2..59415f6f88 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -163,7 +163,7 @@ class _NotifierUserStream: """ # Immediately wake up stream if something has already since happened # since their last token. 
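The position-to-token conversion used in the sync handler above is small enough to check directly; a sketch assuming the post-change `synapse.types` is importable:

    from synapse.types import PersistedEventPosition, RoomStreamToken

    pos = PersistedEventPosition("master", 1234)

    # positions can be compared against stream tokens...
    assert pos.persisted_after(RoomStreamToken(None, 1233))

    # ...and converted into a room stream token that orders correctly within one room
    assert pos.to_room_stream_token() == RoomStreamToken(None, 1234)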
- if self.last_notified_token.is_after(token): + if self.last_notified_token != token: return _NotificationListener(defer.succeed(self.current_token)) else: return _NotificationListener(self.notify_deferred.observe()) @@ -470,7 +470,7 @@ class Notifier: async def check_for_updates( before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: - if not after_token.is_after(before_token): + if after_token == before_token: return EventStreamResult([], (from_token, from_token)) events = [] # type: List[EventBase] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 55af3d41ea..e165429cad 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import PersistedEventPosition, RoomStreamToken, UserID +from synapse.types import PersistedEventPosition, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -152,9 +152,7 @@ class ReplicationDataHandler: if event.type == EventTypes.Member: extra_users = (UserID.from_string(event.state_key),) - max_token = RoomStreamToken( - None, self.store.get_room_max_stream_ordering() - ) + max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) self.notifier.on_new_room_event( event, event_pos, max_token, extra_users diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 5c5f00b213..ba53f66f02 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -109,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = await self.store.get_topological_token_for_event(event_id) + room_token = await self.store.get_topological_token_for_event(event_id) + token = str(room_token) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 92e96468b4..37249f1e3f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -35,7 +35,6 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. 
""" - import abc import logging from collections import namedtuple @@ -54,7 +53,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import Collection, RoomStreamToken +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -305,6 +304,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() + def get_room_max_token(self) -> RoomStreamToken: + return RoomStreamToken(None, self.get_room_max_stream_ordering()) + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], @@ -611,26 +613,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): allow_none=allow_none, ) - async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken: - """The stream token for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream token. + async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: + """Get the persisted position for an event """ - stream_id = await self.get_stream_id_for_event(event_id) - return RoomStreamToken(None, stream_id) + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "instance_name"), + desc="get_position_for_event", + ) + + return PersistedEventPosition( + row["instance_name"] or "master", row["stream_ordering"] + ) - async def get_topological_token_for_event(self, event_id: str) -> str: + async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A "t%d-%d" topological token. + A `RoomStreamToken` topological token. 
""" row = await self.db_pool.simple_select_one( table="events", @@ -638,7 +642,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -687,8 +691,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: topo = None internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) + internal.before = RoomStreamToken(topo, stream - 1) + internal.after = RoomStreamToken(topo, stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index ded6cf9655..72939f3984 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -229,7 +229,7 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return RoomStreamToken(None, self.main_store.get_current_events_token()) + return self.main_store.get_room_max_token() async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False @@ -247,11 +247,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) - max_persisted_id = self.main_store.get_current_events_token() event_stream_id = event.internal_metadata.stream_ordering pos = PersistedEventPosition(self._instance_name, event_stream_id) - return pos, RoomStreamToken(None, max_persisted_id) + return pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/types.py b/synapse/types.py index ec39f9e1e8..02bcc197ec 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -413,6 +413,18 @@ class RoomStreamToken: pass raise SynapseError(400, "Invalid token %r" % (string,)) + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. 
+ """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + return RoomStreamToken(None, max_stream) + def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) @@ -458,31 +470,20 @@ class StreamToken: def room_stream_id(self): return self.room_key.stream - def is_after(self, other): - """Does this token contain events that the other doesn't?""" - return ( - (other.room_stream_id < self.room_stream_id) - or (int(other.presence_key) < int(self.presence_key)) - or (int(other.typing_key) < int(self.typing_key)) - or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.account_data_key) < int(self.account_data_key)) - or (int(other.push_rules_key) < int(self.push_rules_key)) - or (int(other.to_device_key) < int(self.to_device_key)) - or (int(other.device_list_key) < int(self.device_list_key)) - or (int(other.groups_key) < int(self.groups_key)) - ) - def copy_and_advance(self, key, new_value) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. """ - new_token = self.copy_and_replace(key, new_value) if key == "room_key": - new_id = new_token.room_stream_id - old_id = self.room_stream_id - else: - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) + new_token = self.copy_and_replace( + "room_key", self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + if old_id < new_id: return new_token else: @@ -509,6 +510,18 @@ class PersistedEventPosition: def persisted_after(self, token: RoomStreamToken) -> bool: return token.stream < self.stream + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarentees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0a567b032f..a3287011e9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -902,15 +902,15 @@ class RoomMessageListTestCase(RoomBase): # Send a first message in the room, which will be removed by the purge. first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] - first_token = self.get_success( - store.get_topological_token_for_event(first_event_id) + first_token = str( + self.get_success(store.get_topological_token_for_event(first_event_id)) ) # Send a second message in the room, which won't be removed, and which we'll # use as the marker to purge events before. 
second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] - second_token = self.get_success( - store.get_topological_token_for_event(second_event_id) + second_token = str( + self.get_success(store.get_topological_token_for_event(second_event_id)) ) # Send a third event in the room to ensure we don't fall under any edge case diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 918387733b..723cd28933 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -47,8 +47,8 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = self.get_success( - store.get_topological_token_for_event(last["event_id"]) + event = str( + self.get_success(store.get_topological_token_for_event(last["event_id"])) ) # Purge everything before this topological token @@ -74,12 +74,10 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = self.get_success( + token = self.get_success( storage.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format( - *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) - ) + event = "t{}-{}".format(token.topological + 1, token.stream + 1) # Purge everything before this topological token purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) -- cgit 1.5.1 From 6d2d42f8fb04599713d3e6e7fc3bc4c9b7063c9a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 29 Sep 2020 22:26:28 +0100 Subject: Rewrite BucketCollector This was a bit unweildy for what I wanted: in particular, I wanted to assign each measurement straight into a bucket, rather than storing an intermediate Counter which didn't do any bucketing at all. I've replaced it with something that is hopefully a bit easier to use. (I'm not entirely sure what the difference between a HistogramMetricFamily and a GaugeHistogramMetricFamily is, but given our counters can go down as well as up the latter *sounds* more accurate?) --- synapse/metrics/__init__.py | 115 ++++++++++++++++++------------ synapse/storage/databases/main/metrics.py | 26 +++---- tests/storage/test_event_metrics.py | 19 ++--- 3 files changed, 89 insertions(+), 71 deletions(-) (limited to 'tests') diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index a1f7ca3449..b8d2a8e8a9 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -15,6 +15,7 @@ import functools import gc +import itertools import logging import os import platform @@ -27,8 +28,8 @@ from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import ( REGISTRY, CounterMetricFamily, + GaugeHistogramMetricFamily, GaugeMetricFamily, - HistogramMetricFamily, ) from twisted.internet import reactor @@ -46,7 +47,7 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]] +all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -205,63 +206,83 @@ class InFlightGauge: all_gauges[self.name] = self -@attr.s(slots=True, hash=True) -class BucketCollector: - """ - Like a Histogram, but allows buckets to be point-in-time instead of - incrementally added to. +class GaugeBucketCollector: + """Like a Histogram, but the buckets are Gauges which are updated atomically. 
- Args: - name (str): Base name of metric to be exported to Prometheus. - data_collector (callable -> dict): A synchronous callable that - returns a dict mapping bucket to number of items in the - bucket. If these buckets are not the same as the buckets - given to this class, they will be remapped into them. - buckets (list[float]): List of floats/ints of the buckets to - give to Prometheus. +Inf is ignored, if given. + The data is updated by calling `update_data` with an iterable of measurements. + We assume that the data is updated less frequently than it is reported to + Prometheus, and optimise for that case. """ - name = attr.ib() - data_collector = attr.ib() - buckets = attr.ib() + __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric") - def collect(self): + def __init__( + self, + name: str, + documentation: str, + buckets: Iterable[float], + registry=REGISTRY, + ): + """ + Args: + name: base name of metric to be exported to Prometheus. (a _bucket suffix + will be added.) + documentation: help text for the metric + buckets: The top bounds of the buckets to report + registry: metric registry to register with + """ + self._name = name + self._documentation = documentation - # Fetch the data -- this must be synchronous! - data = self.data_collector() + # the tops of the buckets + self._bucket_bounds = [float(b) for b in buckets] + if self._bucket_bounds != sorted(self._bucket_bounds): + raise ValueError("Buckets not in sorted order") - buckets = {} # type: Dict[float, int] + if self._bucket_bounds[-1] != float("inf"): + self._bucket_bounds.append(float("inf")) - res = [] - for x in data.keys(): - for i, bound in enumerate(self.buckets): - if x <= bound: - buckets[bound] = buckets.get(bound, 0) + data[x] + self._metric = self._values_to_metric([]) + registry.register(self) - for i in self.buckets: - res.append([str(i), buckets.get(i, 0)]) + def collect(self): + yield self._metric - res.append(["+Inf", sum(data.values())]) + def update_data(self, values: Iterable[float]): + """Update the data to be reported by the metric - metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) + The existing data is cleared, and each measurement in the input is assigned + to the relevant bucket. + """ + self._metric = self._values_to_metric(values) + + def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFamily: + total = 0.0 + bucket_values = [0 for _ in self._bucket_bounds] + + for v in values: + # assign each value to a bucket + for i, bound in enumerate(self._bucket_bounds): + if v <= bound: + bucket_values[i] += 1 + break + + # ... and increment the sum + total += v + + # now, aggregate the bucket values so that they count the number of entries in + # that bucket or below. 
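A worked example of the bucketing just described: with bounds [1, 2, 3, inf] and measurements [1, 2, 2, 7], the per-bucket counts are [1, 2, 0, 1], and accumulating them yields the at-or-below form that a gauge histogram reports:

    import itertools

    bucket_values = [1, 2, 0, 1]                 # counts per bucket for [1, 2, 2, 7]
    accumulated = list(itertools.accumulate(bucket_values))
    assert accumulated == [1, 3, 3, 4]           # last entry equals the sample count
    assert sum([1, 2, 2, 7]) == 12               # the gsum reported alongside the buckets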
+ accumulated_values = itertools.accumulate(bucket_values) + + return GaugeHistogramMetricFamily( + self._name, + self._documentation, + buckets=list( + zip((str(b) for b in self._bucket_bounds), accumulated_values) + ), + gsum_value=total, ) - yield metric - - def __attrs_post_init__(self): - self.buckets = [float(x) for x in self.buckets if x != "+Inf"] - if self.buckets != sorted(self.buckets): - raise ValueError("Buckets not sorted") - - self.buckets = tuple(self.buckets) - - if self.name in all_gauges.keys(): - logger.warning("%s already registered, reregistering" % (self.name,)) - REGISTRY.unregister(all_gauges.pop(self.name)) - - REGISTRY.register(self) - all_gauges[self.name] = self # diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 686052bd83..4efc093b9e 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import typing -from collections import Counter -from synapse.metrics import BucketCollector +from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -23,6 +21,14 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +# Collect metrics on the number of forward extremities that exist. +_extremities_collecter = GaugeBucketCollector( + "synapse_forward_extremities", + "Number of rooms on the server with the given number of forward extremities" + " or fewer", + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], +) + class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """Functions to pull various metrics from the DB, for e.g. phone home @@ -32,18 +38,6 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - # Collect metrics on the number of forward extremities that exist. 
- # Counter of number of extremities to count - self._current_forward_extremities_amount = ( - Counter() - ) # type: typing.Counter[int] - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - # Read the extrems every 60 minutes def read_forward_extremities(): # run as a background process to make sure that the database transactions @@ -65,7 +59,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return txn.fetchall() res = await self.db_pool.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = Counter([x[0] for x in res]) + _extremities_collecter.update_data(x[0] for x in res) async def count_daily_messages(self): """ diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 949846fe33..3957471f3f 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase): self.reactor.advance(60 * 60 * 1000) self.pump(1) - items = set( + items = list( filter( lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY).split(b"\n"), + generate_latest(REGISTRY, emit_help=False).split(b"\n"), ) ) - expected = { + expected = [ b'synapse_forward_extremities_bucket{le="1.0"} 0.0', b'synapse_forward_extremities_bucket{le="2.0"} 2.0', b'synapse_forward_extremities_bucket{le="3.0"} 2.0', @@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase): b'synapse_forward_extremities_bucket{le="100.0"} 3.0', b'synapse_forward_extremities_bucket{le="200.0"} 3.0', b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - } - + # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9, + # "inf" is valid: "this includes variants such as inf" + b'synapse_forward_extremities_bucket{le="inf"} 3.0', + b"# TYPE synapse_forward_extremities_gcount gauge", + b"synapse_forward_extremities_gcount 3.0", + b"# TYPE synapse_forward_extremities_gsum gauge", + b"synapse_forward_extremities_gsum 10.0", + ] self.assertEqual(items, expected) -- cgit 1.5.1 From 8b40843392e2df80d4f1108295ae6acd972100b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Sep 2020 13:02:43 -0400 Subject: Allow additional SSO properties to be passed to the client (#8413) --- changelog.d/8413.feature | 1 + docs/sample_config.yaml | 8 ++ docs/sso_mapping_providers.md | 14 +++- docs/workers.md | 16 ++++ synapse/config/oidc_config.py | 8 ++ synapse/handlers/auth.py | 60 ++++++++++++++- synapse/handlers/oidc_handler.py | 56 +++++++++++++- synapse/rest/client/v1/login.py | 22 ++++-- tests/handlers/test_oidc.py | 160 +++++++++++++++++++++++++-------------- 9 files changed, 278 insertions(+), 67 deletions(-) create mode 100644 changelog.d/8413.feature (limited to 'tests') diff --git a/changelog.d/8413.feature b/changelog.d/8413.feature new file mode 100644 index 0000000000..abe40a901c --- /dev/null +++ b/changelog.d/8413.feature @@ -0,0 +1 @@ +Support passing additional single sign-on parameters to the client. 
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 70cc06a6d8..066844b5a9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1748,6 +1748,14 @@ oidc_config: # #display_name_template: "{{ user.given_name }} {{ user.last_name }}" + # Jinja2 templates for extra attributes to send back to the client during + # login. + # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{ user.birthdate }}" + # Enable CAS for registration and login. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index abea432343..32b06aa2c5 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods: - This method must return a string, which is the unique identifier for the user. Commonly the ``sub`` claim of the response. * `map_user_attributes(self, userinfo, token)` - - This method should be async. + - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user information from. @@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods: - Returns a dictionary with two keys: - localpart: A required string, used to generate the Matrix ID. - displayname: An optional string, the display name for the user. +* `get_extra_attributes(self, userinfo, token)` + - This method must be async. + - Arguments: + - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user + information from. + - `token` - A dictionary which includes information necessary to make + further requests to the OpenID provider. + - Returns a dictionary that is suitable to be serialized to JSON. This + will be returned as part of the response during a successful login. + + Note that care should be taken to not overwrite any of the parameters + usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). ### Default OpenID Mapping Provider diff --git a/docs/workers.md b/docs/workers.md index df0ac84d94..ad4d8ca9f2 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -243,6 +243,22 @@ for the room are in flight: ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$ +Additionally, the following endpoints should be included if Synapse is configured +to use SSO (you only need to include the ones for whichever SSO provider you're +using): + + # OpenID Connect requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ + ^/_synapse/oidc/callback$ + + # SAML requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ + ^/_matrix/saml2/authn_response$ + + # CAS requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ + ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ + Note that a HTTP listener with `client` and `federation` resources must be configured in the `worker_listeners` option in the worker config. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 70fc8a2f62..f924116819 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -204,6 +204,14 @@ class OIDCConfig(Config): # If unset, no displayname will be set. # #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + + # Jinja2 templates for extra attributes to send back to the client during + # login. 
+ # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{{{ user.birthdate }}}}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0322b60cfc..00eae92052 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: } +@attr.s(slots=True) +class SsoLoginExtraAttributes: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib(type=int) + extra_attributes = attr.ib(type=JsonDict) + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -239,6 +248,10 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + # A mapping of user ID to extra attributes to include in the login + # response. + self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + async def validate_user_via_ui_auth( self, requester: Requester, @@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler): registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler): request: The request to complete. client_redirect_url: The URL to which to redirect the user at the end of the process. + extra_attributes: Extra attributes which will be passed to the client + during successful login. Must be JSON serializable. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return - self._complete_sso_login(registered_user_id, request, client_redirect_url) + self._complete_sso_login( + registered_user_id, request, client_redirect_url, extra_attributes + ) def _complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + # Store any extra attributes which will be passed in the login response. + # Note that this is per-user so it may overwrite a previous value, this + # is considered OK since the newest SSO attributes should be most valid. + if extra_attributes: + self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( + self._clock.time_msec(), extra_attributes, + ) + # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id @@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) + async def _sso_login_callback(self, login_result: JsonDict) -> None: + """ + A login callback which might add additional attributes to the login response. + + Args: + login_result: The data to be sent to the client. Includes the user + ID and access token. + """ + # Expire attributes before processing. Note that there shouldn't be any + # valid logins that still have extra attributes. 
+ self._expire_sso_extra_attributes() + + extra_attributes = self._extra_attributes.get(login_result["user_id"]) + if extra_attributes: + login_result.update(extra_attributes.extra_attributes) + + def _expire_sso_extra_attributes(self) -> None: + """ + Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid. + """ + # TODO This should match the amount of time the macaroon is valid for. + LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000 + expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME + to_expire = set() + for user_id, data in self._extra_attributes.items(): + if data.creation_time < expire_before: + to_expire.add(user_id) + for user_id in to_expire: + logger.debug("Expiring extra attributes for user %s", user_id) + del self._extra_attributes[user_id] + @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 0e06e4408d..19cd652675 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -37,7 +37,7 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -707,6 +707,15 @@ class OidcHandler: self._render_error(request, "mapping_error", str(e)) return + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + # and finally complete the login if ui_auth_session_id: await self._auth_handler.complete_sso_ui_auth( @@ -714,7 +723,7 @@ class OidcHandler: ) else: await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url + user_id, request, client_redirect_url, extra_attributes ) def _generate_oidc_session_token( @@ -984,7 +993,7 @@ class OidcMappingProvider(Generic[C]): async def map_user_attributes( self, userinfo: UserInfo, token: Token ) -> UserAttribute: - """Map a ``UserInfo`` objects into user attributes. + """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider @@ -995,6 +1004,18 @@ class OidcMappingProvider(Generic[C]): """ raise NotImplementedError() + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + """Map a `UserInfo` object into additional attributes passed to the client during login. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing additional attributes. Must be JSON serializable. 
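A sketch of what a custom mapping provider using the new hook could look like; the class name and OIDC claims here are illustrative, not taken from the codebase:

    from synapse.handlers.oidc_handler import OidcMappingProvider

    class ExampleMappingProvider(OidcMappingProvider):
        """Illustrative provider that returns extra login attributes."""

        def __init__(self, parsed_config):
            self._config = parsed_config

        @staticmethod
        def parse_config(config: dict):
            return config

        def get_remote_user_id(self, userinfo):
            return userinfo["sub"]

        async def map_user_attributes(self, userinfo, token):
            return {"localpart": userinfo["preferred_username"], "display_name": None}

        async def get_extra_attributes(self, userinfo, token):
            # merged into the JSON body of a successful /login response
            return {"birthdate": userinfo.get("birthdate")}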
+ """ + return {} + # Used to clear out "None" values in templates def jinja_finalize(thing): @@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig: subject_claim = attr.ib() # type: str localpart_template = attr.ib() # type: Template display_name_template = attr.ib() # type: Optional[Template] + extra_attributes = attr.ib() # type: Dict[str, Template] class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1047,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): % (e,) ) + extra_attributes = {} # type Dict[str, Template] + if "extra_attributes" in config: + extra_attributes_config = config.get("extra_attributes") or {} + if not isinstance(extra_attributes_config, dict): + raise ConfigError( + "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" + ) + + for key, value in extra_attributes_config.items(): + try: + extra_attributes[key] = env.from_string(value) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" + % (key, e) + ) + return JinjaOidcMappingConfig( subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + extra_attributes=extra_attributes, ) def get_remote_user_id(self, userinfo: UserInfo) -> str: @@ -1071,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name = None return UserAttribute(localpart=localpart, display_name=display_name) + + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + extras = {} # type: Dict[str, str] + for key, template in self._config.extra_attributes.items(): + try: + extras[key] = template.render(user=userinfo).strip() + except Exception as e: + # Log an error and skip this value (don't break login for this). + logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e)) + return extras diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 250b03a025..b9347b87c7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -284,9 +284,7 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[ - Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] - ] = None, + callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to @@ -299,12 +297,12 @@ class LoginRestServlet(RestServlet): Args: user_id: ID of the user to register. login_submission: Dictionary of login information. - callback: Callback function to run after registration. + callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. Returns: - result: Dictionary of account information after successful registration. + result: Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in @@ -339,14 +337,24 @@ class LoginRestServlet(RestServlet): return result async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + + Returns: + The body of the JSON response. 
+ """ token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = await self._complete_login(user_id, login_submission) - return result + return await self._complete_login( + user_id, login_submission, self.auth_handler._sso_login_callback + ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5910772aa8..d5087e58be 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -21,7 +21,6 @@ from mock import Mock, patch import attr import pymacaroons -from twisted.internet import defer from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone @@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider): async def map_user_attributes(self, userinfo, token): return {"localpart": userinfo["username"], "display_name": None} + # Do not include get_extra_attributes to test backwards compatibility paths. + + +class TestMappingProviderExtra(TestMappingProvider): + async def get_extra_attributes(self, userinfo, token): + return {"phone": userinfo["phone"]} + def simple_async_mock(return_value=None, raises=None): # AsyncMock is not available in python3.5, this mimics part of its behaviour @@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase): config = self.default_config() config["public_baseurl"] = BASE_URL - oidc_config = config.get("oidc_config", {}) + oidc_config = {} oidc_config["enabled"] = True oidc_config["client_id"] = CLIENT_ID oidc_config["client_secret"] = CLIENT_SECRET @@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config["user_mapping_provider"] = { "module": __name__ + ".TestMappingProvider", } + + # Update this config with what's in the default config so that + # override_config works as expected. 
+ oidc_config.update(config.get("oidc_config", {})) config["oidc_config"] = oidc_config hs = self.setup_test_homeserver( @@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) - @defer.inlineCallbacks def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + metadata = self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_metadata()) + self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) - @defer.inlineCallbacks def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - yield defer.ensureDeferred(self.handler.load_metadata()) + self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) - @defer.inlineCallbacks def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + jwks = self.get_success(self.handler.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cached… self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_jwks()) + self.get_success(self.handler.load_jwks()) self.http_client.get_json.assert_not_called() # …unless forced self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.get_success(self.handler.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - with self.assertRaises(RuntimeError): - yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.get_failure(self.handler.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used self.handler._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.handler.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # This should not throw self.handler._validate_metadata() - @defer.inlineCallbacks def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) - url = yield defer.ensureDeferred( + url = self.get_success( self.handler.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) @@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(params["nonce"], [nonce]) self.assertEqual(redirect, "http://client/redirect") - @defer.inlineCallbacks def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" self.handler._render_error = 
Mock() request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "") request.args[b"error_description"] = [b"some description"] - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "some description") - @defer.inlineCallbacks def test_callback(self): """Code callback works and display errors if something went wrong. @@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "foo", "preferred_username": "bar", } - user_id = UserID("foo", "domain.org") + user_id = "@foo:domain.org" self.handler._render_error = Mock(return_value=None) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) @@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - session = self.handler._generate_oidc_session_token( + request.getCookie.return_value = self.handler._generate_oidc_session_token( state=state, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request.getCookie.return_value = session request.args = {} request.args[b"code"] = [code.encode("utf-8")] @@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] request.getClientIP.return_value = ip_address - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, + user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock( raises=MappingException() ) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error") self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) # Handle ID token errors self.handler._parse_id_token = simple_async_mock(raises=Exception()) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") self.handler._auth_handler.complete_sso_login.reset_mock() @@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # With userinfo fetching self.handler._scopes = [] # do not ask the "openid" scope - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, + user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() @@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase): # Handle userinfo fetching error self.handler._fetch_userinfo = 
simple_async_mock(raises=Exception()) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure self.handler._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") - @defer.inlineCallbacks def test_callback_session(self): """The callback verifies the session presence and validity""" self.handler._render_error = Mock(return_value=None) @@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase): # Missing cookie request.args = {} request.getCookie.return_value = None - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("missing_session", "No session cookie found") # Missing session parameter request.args = {} request.getCookie.return_value = "session" - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request", "State parameter is missing") # Invalid cookie request.args = {} request.args[b"state"] = [b"state"] request.getCookie.return_value = "session" - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_session") # Mismatching session @@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args = {} request.args[b"state"] = [b"mismatching state"] request.getCookie.return_value = session - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mismatching_session") # Valid session request.args = {} request.args[b"state"] = [b"state"] request.getCookie.return_value = session - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) - @defer.inlineCallbacks def test_exchange_code(self): """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} @@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + ret = self.get_success(self.handler._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "foo", "error_description": "bar"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "foo") - self.assertEqual(exc.exception.error_description, "bar") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "foo") + self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body self.http_client.request = simple_async_mock( @@ -557,9 +556,8 @@ class 
OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "server_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body self.http_client.request = simple_async_mock( @@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "internal_server_error"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "internal_server_error") + + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "server_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field self.http_client.request = simple_async_mock( @@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "some_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "some_error") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "module": __name__ + ".TestMappingProviderExtra" + } + } + } + ) + def test_extra_attributes(self): + """ + Login while using a mapping provider that implements get_extra_attributes. 
+ """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "phone": "1234567", + } + user_id = "@foo:domain.org" + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) + + state = "state" + client_redirect_url = "http://client/redirect" + request.getCookie.return_value = self.handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + + request.args = {} + request.args[b"code"] = [b"code"] + request.args[b"state"] = [state.encode("utf-8")] + + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [b"Browser"] + request.getClientIP.return_value = "10.0.0.1" + + self.get_success(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, {"phone": "1234567"}, + ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" -- cgit 1.5.1 From 7941372ec84786f85ae6d75fd2d7a4af5b72ac98 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Sep 2020 20:29:19 +0100 Subject: Make token serializing/deserializing async (#8427) The idea is that in future tokens will encode a mapping of instance to position. However, we don't want to include the full instance name in the string representation, so instead we'll have a mapping between instance name and an immutable integer ID in the DB that we can use instead. We'll then do the lookup when we serialize/deserialize the token (we could alternatively pass around an `Instance` type that includes both the name and ID, but that turns out to be a lot more invasive). --- changelog.d/8427.misc | 1 + synapse/handlers/events.py | 4 +-- synapse/handlers/initial_sync.py | 14 ++++----- synapse/handlers/pagination.py | 8 ++--- synapse/handlers/room.py | 8 +++-- synapse/handlers/search.py | 8 ++--- synapse/rest/admin/__init__.py | 2 +- synapse/rest/client/v1/events.py | 3 +- synapse/rest/client/v1/initial_sync.py | 3 +- synapse/rest/client/v1/room.py | 11 +++++-- synapse/rest/client/v2_alpha/keys.py | 3 +- synapse/rest/client/v2_alpha/sync.py | 10 +++--- synapse/storage/databases/main/purge_events.py | 8 ++--- synapse/streams/config.py | 9 +++--- synapse/types.py | 43 +++++++++++++++++++++----- tests/rest/client/v1/test_rooms.py | 30 +++++++++++++----- tests/storage/test_purge.py | 9 ++++-- 17 files changed, 115 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8427.misc (limited to 'tests') diff --git a/changelog.d/8427.misc b/changelog.d/8427.misc new file mode 100644 index 0000000000..c9656b9112 --- /dev/null +++ b/changelog.d/8427.misc @@ -0,0 +1 @@ +Make stream token serializing/deserializing async. 
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 0875b74ea8..539b4fc32e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler): chunk = { "chunk": chunks, - "start": tokens[0].to_string(), - "end": tokens[1].to_string(), + "start": await tokens[0].to_string(self.store), + "end": await tokens[1].to_string(self.store), } return chunk diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 43f15435de..39a85801c1 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler): messages, time_now=time_now, as_client_event=as_client_event ) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), } d["state"] = await self._event_serializer.serialize_events( @@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler): ], "account_data": account_data_events, "receipts": receipt, - "end": now_token.to_string(), + "end": await now_token.to_string(self.store), } return ret @@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": ( await self._event_serializer.serialize_events( @@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": state, "presence": presence, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index d6779a4b44..2c2a633938 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -413,8 +413,8 @@ class PaginationHandler: if not events: return { "chunk": [], - "start": from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } state = None @@ -442,8 +442,8 @@ class PaginationHandler: events, time_now, as_client_event=as_client_event ) ), - "start": from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } if state: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 836b3f381a..d5f7c78edf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1077,11 +1077,13 @@ class RoomContextHandler: # the token, which we replace. 
token = StreamToken.START - results["start"] = token.copy_and_replace( + results["start"] = await token.copy_and_replace( "room_key", results["start"] - ).to_string() + ).to_string(self.store) - results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() + results["end"] = await token.copy_and_replace( + "room_key", results["end"] + ).to_string(self.store) return results diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6a76c20d79..e9402e6e2e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -362,13 +362,13 @@ class SearchHandler(BaseHandler): self.storage, user.to_string(), res["events_after"] ) - res["start"] = now_token.copy_and_replace( + res["start"] = await now_token.copy_and_replace( "room_key", res["start"] - ).to_string() + ).to_string(self.store) - res["end"] = now_token.copy_and_replace( + res["end"] = await now_token.copy_and_replace( "room_key", res["end"] - ).to_string() + ).to_string(self.store) if include_profile: senders = { diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba53f66f02..57cac22252 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -110,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet): raise SynapseError(400, "Event is for wrong room.") room_token = await self.store.get_topological_token_for_event(event_id) - token = str(room_token) + token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 985d994f6b..1ecb77aa26 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet): if b"room_id" in request.args: room_id = request.args[b"room_id"][0].decode("ascii") - pagin_config = PaginationConfig.from_request(request) + pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in request.args: try: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index d7042786ce..91da0ee573 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7e64a2e0fe..b63389e5fe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -451,6 +451,7 @@ class 
RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet): if at_token_string is None: at_token = None else: - at_token = StreamToken.from_string(at_token_string) + at_token = await StreamToken.from_string(self.store, at_token_string) # let you filter down on particular memberships. # XXX: this may not be the best shape for this API - we could pass in a filter @@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request, default_limit=10) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) as_client_event = b"raw" not in request.args filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: @@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 7abd6ff333..55c4606569 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet): # changes after the "to" as well as before. set_tag("to", parse_string(request, "to")) - from_token = StreamToken.from_string(from_token_string) + from_token = await StreamToken.from_string(self.store, from_token_string) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 51e395cc64..6779df952f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self.store = hs.get_datastore() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() @@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet): device_id=device_id, ) + since_token = None if since is not None: - since_token = StreamToken.from_string(since) - else: - since_token = None + since_token = await StreamToken.from_string(self.store, since) # send any outstanding server notices to the user. 
await self._server_notices_sender.on_user_syncing(user.to_string()) @@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), + "next_batch": await sync_result.next_batch.to_string(self.store), } @staticmethod @@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet): result = { "timeline": { "events": serialized_timeline, - "prev_batch": room.timeline.prev_batch.to_string(), + "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, "state": {"events": serialized_state}, diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d7a03cbf7d..ecfc6717b3 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): The set of state groups that are referenced by deleted events. """ + parsed_token = await RoomStreamToken.parse(self, token) + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, - token, + parsed_token, delete_local_events, ) - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - + def _purge_history_txn(self, txn, room_id, token, delete_local_events): # Tables that should be pruned: # event_auth # event_backward_extremities diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 0bdf846edf..fdda21d165 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Optional @@ -21,6 +20,7 @@ import attr from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest +from synapse.storage.databases.main import DataStore from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -39,8 +39,9 @@ class PaginationConfig: limit = attr.ib(type=Optional[int]) @classmethod - def from_request( + async def from_request( cls, + store: "DataStore", request: SynapseRequest, raise_invalid_params: bool = True, default_limit: Optional[int] = None, @@ -54,13 +55,13 @@ class PaginationConfig: if from_tok == "END": from_tok = None # For backwards compat. 
elif from_tok: - from_tok = StreamToken.from_string(from_tok) + from_tok = await StreamToken.from_string(store, from_tok) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: - to_tok = StreamToken.from_string(to_tok) + to_tok = await StreamToken.from_string(store, to_tok) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/synapse/types.py b/synapse/types.py index 02bcc197ec..bd271f9f16 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,17 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, +) import attr from signedjson.key import decode_verify_key_bytes @@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): from typing import Collection @@ -393,7 +406,7 @@ class RoomStreamToken: stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) @classmethod - def parse(cls, string: str) -> "RoomStreamToken": + async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -428,7 +441,7 @@ class RoomStreamToken: def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) - def __str__(self) -> str: + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) else: @@ -453,18 +466,32 @@ class StreamToken: START = None # type: StreamToken @classmethod - def from_string(cls, string): + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": try: keys = string.split(cls._SEPARATOR) while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") - return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:])) + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) except Exception: raise SynapseError(400, "Invalid Token") - def to_string(self): - return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)]) + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + str(self.groups_key), + ] + ) @property def room_stream_id(self): @@ -493,7 +520,7 @@ class StreamToken: return attr.evolve(self, **{key: new_value}) -StreamToken.START = StreamToken.from_string("s0_0") +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) @attr.s(slots=True, frozen=True) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a3287011e9..0d809d25d5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -902,16 +902,18 @@ class RoomMessageListTestCase(RoomBase): # Send a first message in the room, which will be removed by the purge. 
first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] - first_token = str( - self.get_success(store.get_topological_token_for_event(first_event_id)) + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) ) + first_token_str = self.get_success(first_token.to_string(store)) # Send a second message in the room, which won't be removed, and which we'll # use as the marker to purge events before. second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] - second_token = str( - self.get_success(store.get_topological_token_for_event(second_event_id)) + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) ) + second_token_str = self.get_success(second_token.to_string(store)) # Send a third event in the room to ensure we don't fall under any edge case # due to our marker being the latest forward extremity in the room. @@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) @@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase): pagination_handler._purge_history( purge_id=purge_id, room_id=self.room_id, - token=second_token, + token=second_token_str, delete_local_events=True, ) ) @@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) @@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 723cd28933..cc1f3c53c5 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = str( - self.get_success(store.get_topological_token_for_event(last["event_id"])) + token = self.get_success( + store.get_topological_token_for_event(last["event_id"]) ) + token_str = self.get_success(token.to_string(self.hs.get_datastore())) # Purge everything before this topological token - self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) + self.get_success( + storage.purge_events.purge_history(self.room_id, token_str, True) + ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. 
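The test updates above are the same change seen from a caller's point of view: `RoomStreamToken` no longer exposes the wire format through `__str__`, so producing the "t<topological>-<stream>" string is now a separate async step. A condensed sketch of what the tests are doing (names are illustrative):

async def topological_token_str(store, event_id: str) -> str:
    # The store still returns a RoomStreamToken object...
    token = await store.get_topological_token_for_event(event_id)
    # ...but rendering it as the wire-format string is now an async call
    # that takes the datastore.
    return await token.to_string(store)
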
-- cgit 1.5.1 From fa8934b175467d589dd34fae18639cac0d738fc9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 7 Oct 2020 15:15:57 +0100 Subject: Reduce serialization errors in MultiWriterIdGen (#8456) We call `_update_stream_positions_table_txn` a lot, which is an UPSERT that can conflict in `REPEATABLE READ` isolation level. Instead of doing a transaction consisting of a single query we may as well run it outside of a transaction. --- changelog.d/8456.misc | 1 + synapse/storage/database.py | 63 +++++++++++++++++++++++++++++++++-- synapse/storage/engines/_base.py | 17 ++++++++++ synapse/storage/engines/postgres.py | 10 +++++- synapse/storage/engines/sqlite.py | 10 ++++++ synapse/storage/util/id_generators.py | 12 ++++++- tests/storage/test_base.py | 1 + 7 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8456.misc (limited to 'tests') diff --git a/changelog.d/8456.misc b/changelog.d/8456.misc new file mode 100644 index 0000000000..ccd260069b --- /dev/null +++ b/changelog.d/8456.misc @@ -0,0 +1 @@ +Reduce number of serialization errors of `MultiWriterIdGenerator._update_table`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 79ec8f119d..6116191b16 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -403,6 +403,24 @@ class DatabasePool: *args: Any, **kwargs: Any ) -> R: + """Start a new database transaction with the given connection. + + Note: The given func may be called multiple times under certain + failure modes. This is normally fine when in a standard transaction, + but care must be taken if the connection is in `autocommit` mode that + the function will correctly handle being aborted and retried half way + through its execution. + + Args: + conn + desc + after_callbacks + exception_callbacks + func + *args + **kwargs + """ + start = monotonic_time() txn_id = self._TXN_ID @@ -508,7 +526,12 @@ class DatabasePool: sql_txn_timer.labels(desc).observe(duration) async def runInteraction( - self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + desc: str, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Starts a transaction on the database and runs a given function @@ -518,6 +541,18 @@ class DatabasePool: database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transactions + that are only a single query. + + Currently, this is only implemented for Postgres. SQLite will still + run the function inside a transaction. + + WARNING: This means that if func fails half way through then + the changes will *not* be rolled back. `func` may also get + called multiple times if the transaction is retried, so must + correctly handle that case. + args: positional args to pass to `func` kwargs: named args to pass to `func` @@ -538,6 +573,7 @@ class DatabasePool: exception_callbacks, func, *args, + db_autocommit=db_autocommit, **kwargs ) @@ -551,7 +587,11 @@ class DatabasePool: return cast(R, result) async def runWithConnection( - self, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. 
@@ -560,6 +600,9 @@ class DatabasePool: database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. args: positional args to pass to `func` + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transaction + that are only a single query. Currently only affects postgres. kwargs: named args to pass to `func` Returns: @@ -575,6 +618,13 @@ class DatabasePool: start_time = monotonic_time() def inner_func(conn, *args, **kwargs): + # We shouldn't be in a transaction. If we are then something + # somewhere hasn't committed after doing work. (This is likely only + # possible during startup, as `run*` will ensure changes are + # committed/rolled back before putting the connection back in the + # pool). + assert not self.engine.in_transaction(conn) + with LoggingContext("runWithConnection", parent_context) as context: sched_duration_sec = monotonic_time() - start_time sql_scheduling_timer.observe(sched_duration_sec) @@ -584,7 +634,14 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - return func(conn, *args, **kwargs) + try: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, True) + + return func(conn, *args, **kwargs) + finally: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, False) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 908cbc79e3..d6d632dc10 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): """Gets a string giving the server version. For example: '3.22.0' """ ... + + @abc.abstractmethod + def in_transaction(self, conn: Connection) -> bool: + """Whether the connection is currently in a transaction. + """ + ... + + @abc.abstractmethod + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + """Attempt to set the connections autocommit mode. + + When True queries are run outside of transactions. + + Note: This has no effect on SQLite3, so callers still need to + commit/rollback the connections. + """ + ... 
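Together with the `db_autocommit` flag added to `runInteraction`/`runWithConnection` above, these engine hooks let a single-statement write opt out of the wrapping transaction entirely. The sketch below shows the calling pattern the commit message describes; the table name and SQL are illustrative (Postgres-style native upsert), and the real user is `MultiWriterIdGenerator._update_stream_positions_table_txn` in the hunks further down:

from synapse.storage.database import DatabasePool, LoggingTransaction

def _update_position_txn(txn: LoggingTransaction, instance_name: str, stream_id: int) -> None:
    # A single native UPSERT is atomic on its own, so the surrounding
    # transaction adds nothing except REPEATABLE READ serialization
    # failures when two workers race on the same row.
    txn.execute(
        """
        INSERT INTO stream_positions (stream_name, instance_name, stream_id)
        VALUES ('example_stream', ?, ?)
        ON CONFLICT (stream_name, instance_name)
        DO UPDATE SET stream_id = GREATEST(stream_positions.stream_id, EXCLUDED.stream_id)
        """,
        (instance_name, stream_id),
    )

async def update_position(db_pool: DatabasePool, instance_name: str, stream_id: int) -> None:
    await db_pool.runInteraction(
        "update_position",
        _update_position_txn,
        instance_name,
        stream_id,
        # Run the single query outside a transaction; if it fails half way
        # there is nothing to roll back, and retries are safe because the
        # upsert is idempotent for a given (instance, stream_id) pair.
        db_autocommit=True,
    )
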
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index ff39281f85..7719ac32f7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,7 +15,8 @@ import logging -from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.types import Connection logger = logging.getLogger(__name__) @@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine): cursor.execute("SET synchronous_commit TO OFF") cursor.close() + db_conn.commit() @property def can_native_upsert(self): @@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine): return "%i.%i" % (numver / 10000, numver % 10000) else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) + + def in_transaction(self, conn: Connection) -> bool: + return conn.status != self.module.extensions.STATUS_READY # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + return conn.set_session(autocommit=autocommit) # type: ignore diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 8a0f8c89d1..5db0f0b520 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -17,6 +17,7 @@ import threading import typing from synapse.storage.engines import BaseDatabaseEngine +from synapse.storage.types import Connection if typing.TYPE_CHECKING: import sqlite3 # noqa: F401 @@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): db_conn.create_function("rank", 1, _rank) db_conn.execute("PRAGMA foreign_keys = ON;") + db_conn.commit() def is_deadlock(self, error): return False @@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): """ return "%i.%i.%i" % self.module.sqlite_version_info + def in_transaction(self, conn: Connection) -> bool: + return conn.in_transaction # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + # Twisted doesn't let us set attributes on the connections, so we can't + # set the connection to autocommit mode. + pass + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 48efbb5067..ad017207aa 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -24,6 +24,7 @@ from typing_extensions import Deque from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator logger = logging.getLogger(__name__) @@ -552,7 +553,7 @@ class MultiWriterIdGenerator: # do. break - def _update_stream_positions_table_txn(self, txn): + def _update_stream_positions_table_txn(self, txn: Cursor): """Update the `stream_positions` table with newly persisted position. """ @@ -602,10 +603,13 @@ class _MultiWriterCtxManager: stream_ids = attr.ib(type=List[int], factory=list) async def __aenter__(self) -> Union[int, List[int]]: + # It's safe to run this in autocommit mode as fetching values from a + # sequence ignores transaction semantics anyway. 
self.stream_ids = await self.id_gen._db.runInteraction( "_load_next_mult_id", self.id_gen._load_next_mult_id_txn, self.multiple_ids or 1, + db_autocommit=True, ) # Assert the fetched ID is actually greater than any ID we've already @@ -636,10 +640,16 @@ class _MultiWriterCtxManager: # # We only do this on the success path so that the persisted current # position points to a persisted row with the correct instance name. + # + # We do this in autocommit mode as a) the upsert works correctly outside + # transactions and b) reduces the amount of time the rows are locked + # for. If we don't do this then we'll often hit serialization errors due + # to the fact we default to REPEATABLE READ isolation levels. if self.id_gen._writers: await self.id_gen._db.runInteraction( "MultiWriterIdGenerator._update_table", self.id_gen._update_stream_positions_table_txn, + db_autocommit=True, ) return False diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 40ba652248..eac7e4dcd2 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -56,6 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False + fake_engine.in_transaction.return_value = False db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool -- cgit 1.5.1
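For completeness, the consumer side of `_MultiWriterCtxManager` ends up looking roughly like this. The `get_next()` helper and the attribute names are not part of this diff and are shown only to illustrate where the two autocommit queries sit relative to the caller's own transaction:

def _do_write_txn(txn, stream_id: int) -> None:
    # Placeholder for whatever rows the caller persists at this position.
    pass

async def persist_with_new_position(store) -> None:
    # Entering the context manager claims the next ID from the sequence,
    # run with db_autocommit=True since sequences ignore transaction
    # semantics anyway.
    async with store._stream_id_gen.get_next() as stream_id:
        # The caller's own write still runs inside a normal transaction.
        await store.db_pool.runInteraction(
            "persist_with_new_position", _do_write_txn, stream_id
        )
    # On a clean exit the new position is advertised via the autocommit
    # upsert in _update_stream_positions_table_txn, which is the query
    # that used to hit serialization errors under REPEATABLE READ.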