From ab18441573dc14cea1fe4082b2a89b9d392a4b9f Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Fri, 5 Aug 2022 17:09:33 +0200 Subject: Support stable identifiers for MSC2285: private read receipts. (#13273) This adds support for the stable identifiers of MSC2285 while continuing to support the unstable identifiers behind the configuration flag. These will be removed in a future version. --- tests/storage/test_receipts.py | 55 +++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 17 deletions(-) (limited to 'tests/storage') diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index b1a8f8bba7..191c957fb5 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from parameterized import parameterized + from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -23,7 +25,7 @@ OUR_USER_ID = "@our:test" class ReceiptTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor, clock, homeserver) -> None: super().prepare(reactor, clock, homeserver) self.store = homeserver.get_datastores().main @@ -83,10 +85,15 @@ class ReceiptTestCase(HomeserverTestCase): ) ) - def test_return_empty_with_no_data(self): + def test_return_empty_with_no_data(self) -> None: res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) @@ -94,7 +101,11 @@ class ReceiptTestCase(HomeserverTestCase): res = self.get_success( self.store.get_receipts_for_user_with_orderings( OUR_USER_ID, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) @@ -103,12 +114,19 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, None) - def test_get_receipts_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_receipts_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -126,14 +144,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -146,7 +164,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) + self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -169,17 +187,20 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - def test_get_last_receipt_event_id_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -197,7 +218,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) @@ -206,7 +227,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event1_2_id) @@ -222,7 +243,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + OUR_USER_ID, self.room_id1, [receipt_type] ) ) self.assertEqual(res, event1_2_id) @@ -248,14 +269,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From 5ce2887653f9905f1ce13f53149c1e713314a28b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Aug 2022 07:56:16 -0400 Subject: Strengthen tests about deleted old push actions. (#13471) --- changelog.d/13471.misc | 1 + tests/storage/test_event_push_actions.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 changelog.d/13471.misc (limited to 'tests/storage') diff --git a/changelog.d/13471.misc b/changelog.d/13471.misc new file mode 100644 index 0000000000..b55ff32c76 --- /dev/null +++ b/changelog.d/13471.misc @@ -0,0 +1 @@ +Clean-up tests for notifications. diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index ba40124c8a..62fd4aeb2f 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -135,7 +135,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(1, 1, 0) # Delete old event push actions, this should not affect the (summarised) count. + # + # All event push actions are kept for 24 hours, so need to move forward + # in time. + self.pump(60 * 60 * 24) self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + # Double check that the event push actions have been cleared (i.e. that + # any results *must* come from the summary). + result = self.get_success( + self.store.db_pool.simple_select_list( + table="event_push_actions", + keyvalues={"1": 1}, + retcols=("event_id",), + desc="", + ) + ) + self.assertEqual(result, []) _assert_counts(1, 1, 0) _mark_read(last_event_id) -- cgit 1.5.1 From d75512d19ebea6c0f9e38e9f55474fdb6da02b46 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 17 Aug 2022 11:42:01 +0200 Subject: Add forgotten status to Room Details API (#13503) --- changelog.d/13503.feature | 1 + docs/admin_api/rooms.md | 5 +- synapse/rest/admin/rooms.py | 1 + synapse/storage/databases/main/roommember.py | 24 ++++++++++ tests/rest/admin/test_room.py | 1 + tests/storage/test_roommember.py | 70 ++++++++++++++++++++++++++++ 6 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13503.feature (limited to 'tests/storage') diff --git a/changelog.d/13503.feature b/changelog.d/13503.feature new file mode 100644 index 0000000000..4baabd1e32 --- /dev/null +++ b/changelog.d/13503.feature @@ -0,0 +1 @@ +Add forgotten status to Room Details API. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 9aa489e4a3..ac7c54c20e 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -302,6 +302,8 @@ The following fields are possible in the JSON response body: * `state_events` - Total number of state_events of a room. Complexity of the room. * `room_type` - The type of the room taken from the room's creation event; for example "m.space" if the room is a space. If the room does not define a type, the value will be `null`. +* `forgotten` - Whether all local users have + [forgotten](https://spec.matrix.org/latest/client-server-api/#leaving-rooms) the room. The API is: @@ -330,7 +332,8 @@ A response body like the following is returned: "guest_access": null, "history_visibility": "shared", "state_events": 93534, - "room_type": "m.space" + "room_type": "m.space", + "forgotten": false } ``` diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 9d953d58de..68054ffc28 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -303,6 +303,7 @@ class RoomRestServlet(RestServlet): members = await self.store.get_users_in_room(room_id) ret["joined_local_devices"] = await self.store.count_devices_by_users(members) + ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id) return HTTPStatus.OK, ret diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 5e5f607a14..827c1f1efd 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1215,6 +1215,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) + async def is_locally_forgotten_room(self, room_id: str) -> bool: + """Returns whether all local users have forgotten this room_id. + + Args: + room_id: The room ID to query. + + Returns: + Whether the room is forgotten. + """ + + sql = """ + SELECT count(*) > 0 FROM local_current_membership + INNER JOIN room_memberships USING (room_id, event_id) + WHERE + room_id = ? + AND forgotten = 0; + """ + + rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + + # `count(*)` returns always an integer + # If any rows still exist it means someone has not forgotten this room yet + return not rows[0][0] + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index dd5000679a..fd6da557c1 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1633,6 +1633,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("history_visibility", channel.json_body) self.assertIn("state_events", channel.json_body) self.assertIn("room_type", channel.json_body) + self.assertIn("forgotten", channel.json_body) self.assertEqual(room_id_1, channel.json_body["room_id"]) def test_single_room_devices(self) -> None: diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 240b02cb9f..ceec690285 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -23,6 +23,7 @@ from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer +from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -157,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Check that alice's display name is now None self.assertEqual(row[0]["display_name"], None) + def test_room_is_locally_forgotten(self): + """Test that when the last local user has forgotten a room it is known as forgotten.""" + # join two local and one remote user + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.get_success( + event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join") + ) + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_charlie.to_string(), "join" + ) + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # local users leave the room and the room is not forgotten + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "leave" + ) + ) + self.get_success( + event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave") + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # first user forgets the room, room is not forgotten + self.get_success(self.store.forget(self.u_alice, self.room)) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # second (last local) user forgets the room and the room is forgotten + self.get_success(self.store.forget(self.u_bob, self.room)) + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + def test_join_locally_forgotten_room(self): + """Tests if a user joins a forgotten room the room is not forgotten anymore.""" + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # after leaving and forget the room, it is forgotten + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "leave" + ) + ) + self.get_success(self.store.forget(self.u_alice, self.room)) + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # after rejoin the room is not forgotten anymore + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "join" + ) + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: -- cgit 1.5.1 From 682dfcfc0db05d9c99b7615d950997535df4d533 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 30 Aug 2022 11:58:38 +0200 Subject: Fix that user cannot `/forget` rooms after the last member has left (#13546) --- changelog.d/13546.bugfix | 1 + synapse/handlers/room_member.py | 7 ++- tests/handlers/test_room_member.py | 93 +++++++++++++++++++++++++++++++++++++- tests/storage/test_roommember.py | 4 +- 4 files changed, 99 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13546.bugfix (limited to 'tests/storage') diff --git a/changelog.d/13546.bugfix b/changelog.d/13546.bugfix new file mode 100644 index 0000000000..83bc3a61d2 --- /dev/null +++ b/changelog.d/13546.bugfix @@ -0,0 +1 @@ +Fix bug that user cannot `/forget` rooms after the last member has left the room. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 709682622f..e726997d83 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1925,8 +1925,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): ]: raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) - if membership: - await self.store.forget(user_id, room_id) + # In normal case this call is only required if `membership` is not `None`. + # But: After the last member had left the room, the background update + # `_background_remove_left_rooms` is deleting rows related to this room from + # the table `current_state_events` and `get_current_state_events` is `None`. + await self.store.forget(user_id, room_id) def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]: diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 1d13ed1e88..6bbfd5dc84 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -6,7 +6,7 @@ import synapse.rest.admin import synapse.rest.client.login import synapse.rest.client.room from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError +from synapse.api.errors import LimitExceededError, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import FrozenEventV3 from synapse.federation.federation_client import SendJoinResult @@ -17,7 +17,11 @@ from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request from tests.test_utils import make_awaitable -from tests.unittest import FederatingHomeserverTestCase, override_config +from tests.unittest import ( + FederatingHomeserverTestCase, + HomeserverTestCase, + override_config, +) class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): @@ -287,3 +291,88 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa ), LimitExceededError, ) + + +class RoomMemberMasterHandlerTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_room_member_handler() + self.store = hs.get_datastores().main + + # Create two users. + self.alice = self.register_user("alice", "pass") + self.alice_ID = UserID.from_string(self.alice) + self.alice_token = self.login("alice", "pass") + self.bob = self.register_user("bob", "pass") + self.bob_ID = UserID.from_string(self.bob) + self.bob_token = self.login("bob", "pass") + + # Create a room on this homeserver. + self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + + def test_leave_and_forget(self) -> None: + """Tests that forget a room is successfully. The test is performed with two users, + as forgetting by the last user respectively after all users had left the + is a special edge case.""" + self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) + + # alice is not the last room member that leaves and forgets the room + self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token) + self.get_success(self.handler.forget(self.alice_ID, self.room_id)) + self.assertTrue( + self.get_success(self.store.did_forget(self.alice, self.room_id)) + ) + + # the server has not forgotten the room + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room_id)) + ) + + def test_leave_and_forget_last_user(self) -> None: + """Tests that forget a room is successfully when the last user has left the room.""" + + # alice is the last room member that leaves and forgets the room + self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token) + self.get_success(self.handler.forget(self.alice_ID, self.room_id)) + self.assertTrue( + self.get_success(self.store.did_forget(self.alice, self.room_id)) + ) + + # the server has forgotten the room + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room_id)) + ) + + def test_forget_when_not_left(self) -> None: + """Tests that a user cannot not forgets a room that has not left.""" + self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError) + + def test_rejoin_forgotten_by_user(self) -> None: + """Test that a user that has forgotten a room can do a re-join. + The room was not forgotten from the local server. + One local user is still member of the room.""" + self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) + + self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token) + self.get_success(self.handler.forget(self.alice_ID, self.room_id)) + self.assertTrue( + self.get_success(self.store.did_forget(self.alice, self.room_id)) + ) + + # the server has not forgotten the room + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room_id)) + ) + + self.helper.join(self.room_id, user=self.alice, tok=self.alice_token) + # TODO: A join to a room does not invalidate the forgotten cache + # see https://github.com/matrix-org/synapse/issues/13262 + self.store.did_forget.invalidate_all() + self.assertFalse( + self.get_success(self.store.did_forget(self.alice, self.room_id)) + ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index ceec690285..8794401823 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -158,7 +158,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Check that alice's display name is now None self.assertEqual(row[0]["display_name"], None) - def test_room_is_locally_forgotten(self): + def test_room_is_locally_forgotten(self) -> None: """Test that when the last local user has forgotten a room it is known as forgotten.""" # join two local and one remote user self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) @@ -199,7 +199,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.get_success(self.store.is_locally_forgotten_room(self.room)) ) - def test_join_locally_forgotten_room(self): + def test_join_locally_forgotten_room(self) -> None: """Tests if a user joins a forgotten room the room is not forgotten anymore.""" self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) self.assertFalse( -- cgit 1.5.1 From 0e99f07952edcb6396654e34da50ddeb0a211067 Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Thu, 1 Sep 2022 14:31:54 +0200 Subject: Remove support for unstable private read receipts (#13653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- changelog.d/13653.removal | 1 + synapse/api/constants.py | 1 - synapse/config/experimental.py | 3 -- synapse/handlers/receipts.py | 29 +++---------- synapse/replication/tcp/client.py | 5 +-- synapse/rest/client/notifications.py | 1 - synapse/rest/client/read_marker.py | 2 - synapse/rest/client/receipts.py | 2 - synapse/rest/client/versions.py | 1 - .../storage/databases/main/event_push_actions.py | 2 - tests/handlers/test_receipts.py | 48 ++++++---------------- tests/rest/client/test_sync.py | 37 +++++------------ tests/storage/test_receipts.py | 34 ++++++--------- 13 files changed, 44 insertions(+), 122 deletions(-) create mode 100644 changelog.d/13653.removal (limited to 'tests/storage') diff --git a/changelog.d/13653.removal b/changelog.d/13653.removal new file mode 100644 index 0000000000..eb075d4517 --- /dev/null +++ b/changelog.d/13653.removal @@ -0,0 +1 @@ +Remove support for unstable [private read receipts](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c73aea622a..c178ddf070 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -258,7 +258,6 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" READ_PRIVATE: Final = "m.read.private" - UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1ff417539..260db49cad 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,9 +32,6 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (unstable private read receipts) - self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) - # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index d4a866b346..d2bdb9c8be 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -163,10 +163,7 @@ class ReceiptsHandler: if not is_new: return - if self.federation_sender and receipt_type not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: await self.federation_sender.send_read_receipt(receipt) @@ -206,38 +203,24 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id, orig_event_content in room.get("content", {}).items(): event_content = orig_event_content # If there are private read receipts, additional logic is necessary. - if ( - ReceiptTypes.READ_PRIVATE in event_content - or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content - ): + if ReceiptTypes.READ_PRIVATE in event_content: # Make a copy without private read receipts to avoid leaking # other user's private read receipts.. event_content = { receipt_type: receipt_value for receipt_type, receipt_value in event_content.items() - if receipt_type - not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ) + if receipt_type != ReceiptTypes.READ_PRIVATE } # Copy the current user's private read receipt from the # original content, if it exists. - user_private_read_receipt = orig_event_content.get( - ReceiptTypes.READ_PRIVATE, {} - ).get(user_id, None) + user_private_read_receipt = orig_event_content[ + ReceiptTypes.READ_PRIVATE + ].get(user_id, None) if user_private_read_receipt: event_content[ReceiptTypes.READ_PRIVATE] = { user_id: user_private_read_receipt } - user_unstable_private_read_receipt = orig_event_content.get( - ReceiptTypes.UNSTABLE_READ_PRIVATE, {} - ).get(user_id, None) - if user_unstable_private_read_receipt: - event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = { - user_id: user_unstable_private_read_receipt - } # Include the event if there is at least one non-private read # receipt or the current user has a private read receipt. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 1ed7230e32..e4f2201c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,10 +416,7 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index a73322a6a4..61268e3af1 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -62,7 +62,6 @@ class NotificationsServlet(RestServlet): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index aaad8b233f..5e53096539 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -45,8 +45,6 @@ class ReadMarkerRestServlet(RestServlet): ReceiptTypes.FULLY_READ, ReceiptTypes.READ_PRIVATE, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index c6108fc5eb..5b7fad7402 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,8 +49,6 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c9a830cbac..c516cda95d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -95,7 +95,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec - "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 9f410d69de..f4a07de2a3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -274,7 +274,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types=( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) @@ -468,7 +467,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5f70a2db79..b55238650c 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,8 +15,6 @@ from copy import deepcopy from typing import List -from parameterized import parameterized - from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict @@ -27,16 +25,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt(self, receipt_type: str) -> None: + def test_filters_out_private_receipt(self) -> None: self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -50,18 +45,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt_and_ignores_rest( - self, receipt_type: str - ) -> None: + def test_filters_out_private_receipt_and_ignores_rest(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -94,18 +84,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -175,18 +162,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -262,16 +246,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None: + def test_leaves_our_private_and_their_public(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -296,7 +277,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -319,16 +300,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_we_do_not_mutate(self, receipt_type: str) -> None: + def test_we_do_not_mutate(self) -> None: """Ensure the input values are not modified.""" events = [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index de0dec8539..0af643ecd9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -391,7 +391,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["experimental_features"] = {"msc2285_enabled": True} return self.setup_test_homeserver(config=config) @@ -413,17 +412,14 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_read_receipts(self, receipt_type: str) -> None: + def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -432,10 +428,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_public_receipt_can_override_private(self, receipt_type: str) -> None: + def test_public_receipt_can_override_private(self) -> None: """ Sending a public read receipt to the same event which has a private read receipt should cause that receipt to become public. @@ -446,7 +439,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -465,10 +458,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None: + def test_private_receipt_cannot_override_public(self) -> None: """ Sending a private read receipt to the same event which has a public read receipt should cause no change. @@ -489,7 +479,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -554,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): config = super().default_config() config["experimental_features"] = { "msc2654_enabled": True, - "msc2285_enabled": True, } return config @@ -601,10 +590,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_unread_counts(self, receipt_type: str) -> None: + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -638,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok, ) @@ -726,7 +712,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -738,7 +724,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ] ) def test_read_receipts_only_go_down(self, receipt_type: str) -> None: @@ -752,7 +737,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Read last event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -763,7 +748,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}", {}, access_token=self.tok, ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 191c957fb5..c89bfff241 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from parameterized import parameterized from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -92,7 +91,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -104,7 +102,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -117,16 +114,12 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) self.assertEqual(res, None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_receipts_for_user(self, receipt_type: str) -> None: + def test_get_receipts_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -144,14 +137,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -164,7 +157,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -187,20 +180,17 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: + def test_get_last_receipt_event_id_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -218,7 +208,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) @@ -227,7 +217,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event1_2_id) @@ -243,7 +233,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [receipt_type] + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, event1_2_id) @@ -269,14 +259,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From 390b7ce946173b41a61f427ef25a9a1d0371ad0b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 1 Sep 2022 12:52:03 -0400 Subject: Disable calculating unread counts unless the config flag is enabled. (#13694) This avoids doing work that will never be used (since the resulting unread counts will never be sent in a /sync response). The negative of doing this is that unread counts will be incorrect when the feature is initially enabled. --- changelog.d/13694.bugfix | 1 + synapse/config/experimental.py | 3 +++ synapse/push/bulk_push_rule_evaluator.py | 7 +++++- tests/storage/test_event_push_actions.py | 42 +++++++++++++++----------------- 4 files changed, 30 insertions(+), 23 deletions(-) create mode 100644 changelog.d/13694.bugfix (limited to 'tests/storage') diff --git a/changelog.d/13694.bugfix b/changelog.d/13694.bugfix new file mode 100644 index 0000000000..48b9bb5f0a --- /dev/null +++ b/changelog.d/13694.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.20.0 that would cause the unstable unread counts from [MSC2654](https://github.com/matrix-org/matrix-spec-proposals/pull/2654) to be calculated even if the feature is disabled. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 260db49cad..702b81e636 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -71,6 +71,9 @@ class ExperimentalConfig(Config): self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) # MSC2654: Unread counts + # + # Note that enabling this will result in an incorrect unread count for + # previously calculated push actions. self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False) # MSC2815 (allow room moderators to view redacted event content) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index ccd512be54..d1caf8a0f7 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -262,7 +262,12 @@ class BulkPushRuleEvaluator: # This can happen due to out of band memberships return - count_as_unread = _should_count_as_unread(event, context) + # Disable counting as unread unless the experimental configuration is + # enabled, as it can cause additional (unwanted) rows to be added to the + # event_push_actions table. + count_as_unread = False + if self.hs.config.experimental.msc2654_enabled: + count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event) actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {} diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 62fd4aeb2f..fc43d7edd1 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -67,9 +67,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): last_event_id: str - def _assert_counts( - noitf_count: int, unread_count: int, highlight_count: int - ) -> None: + def _assert_counts(noitf_count: int, highlight_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", @@ -82,7 +80,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): counts, NotifCounts( notify_count=noitf_count, - unread_count=unread_count, + unread_count=0, highlight_count=highlight_count, ), ) @@ -112,27 +110,27 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _create_event() - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _rotate() - _assert_counts(1, 1, 0) + _assert_counts(1, 0) event_id = _create_event() - _assert_counts(2, 2, 0) + _assert_counts(2, 0) _rotate() - _assert_counts(2, 2, 0) + _assert_counts(2, 0) _create_event() _mark_read(event_id) - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _create_event() _rotate() - _assert_counts(1, 1, 0) + _assert_counts(1, 0) # Delete old event push actions, this should not affect the (summarised) count. # @@ -151,35 +149,35 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(result, []) - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) event_id = _create_event(True) - _assert_counts(1, 1, 1) + _assert_counts(1, 1) _rotate() - _assert_counts(1, 1, 1) + _assert_counts(1, 1) # Check that adding another notification and rotating after highlight # works. _create_event() _rotate() - _assert_counts(2, 2, 1) + _assert_counts(2, 1) # Check that sending read receipts at different points results in the # right counts. _mark_read(event_id) - _assert_counts(1, 1, 0) + _assert_counts(1, 0) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _create_event(True) - _assert_counts(1, 1, 1) + _assert_counts(1, 1) _mark_read(last_event_id) - _assert_counts(0, 0, 0) + _assert_counts(0, 0) _rotate() - _assert_counts(0, 0, 0) + _assert_counts(0, 0) def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: -- cgit 1.5.1 From c2fe48a6ffb99f553f3eaecb8f15bcbedb58add0 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 7 Sep 2022 10:08:20 +0000 Subject: Rename the `EventFormatVersions` enum values so that they line up with room version numbers. (#13706) --- changelog.d/13706.misc | 1 + synapse/api/room_versions.py | 45 ++++++++++++---------- synapse/event_auth.py | 4 +- synapse/events/__init__.py | 12 +++--- synapse/events/builder.py | 4 +- synapse/events/validator.py | 2 +- synapse/federation/federation_base.py | 2 +- synapse/federation/federation_client.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events_worker.py | 6 +-- tests/storage/databases/main/test_events_worker.py | 2 +- tests/storage/test_event_federation.py | 2 +- tests/test_event_auth.py | 4 +- 13 files changed, 47 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13706.misc (limited to 'tests/storage') diff --git a/changelog.d/13706.misc b/changelog.d/13706.misc new file mode 100644 index 0000000000..65c854c7a9 --- /dev/null +++ b/changelog.d/13706.misc @@ -0,0 +1 @@ +Rename the `EventFormatVersions` enum values so that they line up with room version numbers. \ No newline at end of file diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index a0e4ab6db6..e37acb0f1e 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -19,18 +19,23 @@ import attr class EventFormatVersions: """This is an internal enum for tracking the version of the event format, - independently from the room version. + independently of the room version. + + To reduce confusion, the event format versions are named after the room + versions that they were used or introduced in. + The concept of an 'event format version' is specific to Synapse (the + specification does not mention this term.) """ - V1 = 1 # $id:server event id format - V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 - V3 = 3 # MSC1884-style $hash format: introduced for room v4 + ROOM_V1_V2 = 1 # $id:server event id format: used for room v1 and v2 + ROOM_V3 = 2 # MSC1659-style $hash event id format: used for room v3 + ROOM_V4_PLUS = 3 # MSC1884-style $hash format: introduced for room v4 KNOWN_EVENT_FORMAT_VERSIONS = { - EventFormatVersions.V1, - EventFormatVersions.V2, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V1_V2, + EventFormatVersions.ROOM_V3, + EventFormatVersions.ROOM_V4_PLUS, } @@ -92,7 +97,7 @@ class RoomVersions: V1 = RoomVersion( "1", RoomDisposition.STABLE, - EventFormatVersions.V1, + EventFormatVersions.ROOM_V1_V2, StateResolutionVersions.V1, enforce_key_validity=False, special_case_aliases_auth=True, @@ -110,7 +115,7 @@ class RoomVersions: V2 = RoomVersion( "2", RoomDisposition.STABLE, - EventFormatVersions.V1, + EventFormatVersions.ROOM_V1_V2, StateResolutionVersions.V2, enforce_key_validity=False, special_case_aliases_auth=True, @@ -128,7 +133,7 @@ class RoomVersions: V3 = RoomVersion( "3", RoomDisposition.STABLE, - EventFormatVersions.V2, + EventFormatVersions.ROOM_V3, StateResolutionVersions.V2, enforce_key_validity=False, special_case_aliases_auth=True, @@ -146,7 +151,7 @@ class RoomVersions: V4 = RoomVersion( "4", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=False, special_case_aliases_auth=True, @@ -164,7 +169,7 @@ class RoomVersions: V5 = RoomVersion( "5", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=True, @@ -182,7 +187,7 @@ class RoomVersions: V6 = RoomVersion( "6", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -200,7 +205,7 @@ class RoomVersions: MSC2176 = RoomVersion( "org.matrix.msc2176", RoomDisposition.UNSTABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -218,7 +223,7 @@ class RoomVersions: V7 = RoomVersion( "7", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -236,7 +241,7 @@ class RoomVersions: V8 = RoomVersion( "8", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -254,7 +259,7 @@ class RoomVersions: V9 = RoomVersion( "9", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -272,7 +277,7 @@ class RoomVersions: MSC3787 = RoomVersion( "org.matrix.msc3787", RoomDisposition.UNSTABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -290,7 +295,7 @@ class RoomVersions: V10 = RoomVersion( "10", RoomDisposition.STABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, @@ -308,7 +313,7 @@ class RoomVersions: MSC2716v4 = RoomVersion( "org.matrix.msc2716v4", RoomDisposition.UNSTABLE, - EventFormatVersions.V3, + EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 389b0c5d53..c7d5ef92fc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -109,7 +109,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: if not is_invite_via_3pid: raise AuthError(403, "Event not signed by sender's server") - if event.format_version in (EventFormatVersions.V1,): + if event.format_version in (EventFormatVersions.ROOM_V1_V2,): # Only older room versions have event IDs to check. event_id_domain = get_domain_from_id(event.event_id) @@ -716,7 +716,7 @@ def check_redaction( if user_level >= redact_level: return False - if room_version_obj.event_format == EventFormatVersions.V1: + if room_version_obj.event_format == EventFormatVersions.ROOM_V1_V2: redacter_domain = get_domain_from_id(event.event_id) if not isinstance(event.redacts, str): return False diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 39ad2793d9..b2c9119fd0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -442,7 +442,7 @@ class EventBase(metaclass=abc.ABCMeta): class FrozenEvent(EventBase): - format_version = EventFormatVersions.V1 # All events of this type are V1 + format_version = EventFormatVersions.ROOM_V1_V2 # All events of this type are V1 def __init__( self, @@ -490,7 +490,7 @@ class FrozenEvent(EventBase): class FrozenEventV2(EventBase): - format_version = EventFormatVersions.V2 # All events of this type are V2 + format_version = EventFormatVersions.ROOM_V3 # All events of this type are V2 def __init__( self, @@ -567,7 +567,7 @@ class FrozenEventV2(EventBase): class FrozenEventV3(FrozenEventV2): """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" - format_version = EventFormatVersions.V3 # All events of this type are V3 + format_version = EventFormatVersions.ROOM_V4_PLUS # All events of this type are V3 @property def event_id(self) -> str: @@ -597,11 +597,11 @@ def _event_type_from_format_version( `FrozenEvent` """ - if format_version == EventFormatVersions.V1: + if format_version == EventFormatVersions.ROOM_V1_V2: return FrozenEvent - elif format_version == EventFormatVersions.V2: + elif format_version == EventFormatVersions.ROOM_V3: return FrozenEventV2 - elif format_version == EventFormatVersions.V3: + elif format_version == EventFormatVersions.ROOM_V4_PLUS: return FrozenEventV3 else: raise Exception("No event format %r" % (format_version,)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 17f624b68f..746bd3978d 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -137,7 +137,7 @@ class EventBuilder: # The types of auth/prev events changes between event versions. prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] - if format_version == EventFormatVersions.V1: + if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) prev_events = await self._store.add_event_hashes(prev_event_ids) else: @@ -253,7 +253,7 @@ def create_local_event_from_event_dict( time_now = int(clock.time_msec()) - if format_version == EventFormatVersions.V1: + if format_version == EventFormatVersions.ROOM_V1_V2: event_dict["event_id"] = _create_event_id(clock, hostname) event_dict["origin"] = hostname diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 27c8beba25..a6f0104396 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -45,7 +45,7 @@ class EventValidator: """ self.validate_builder(event) - if event.format_version == EventFormatVersions.V1: + if event.format_version == EventFormatVersions.ROOM_V1_V2: EventID.from_string(event.event_id) required = [ diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 4269a98db2..abe2c1971a 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -194,7 +194,7 @@ async def _check_sigs_on_pdu( # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). - if room_version.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.ROOM_V1_V2: event_domain = get_domain_from_id(pdu.event_id) if event_domain != sender_domain: try: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7ee2974bb1..4a4289ee7c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1190,7 +1190,7 @@ class FederationClient(FederationBase): # Otherwise, consider it a legitimate error and raise. err = e.to_synapse_error() if self._is_unknown_endpoint(e, err): - if room_version.event_format != EventFormatVersions.V1: + if room_version.event_format != EventFormatVersions.ROOM_V1_V2: raise SynapseError( 400, "User's homeserver does not support this room version", diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c836078da6..e687f87eca 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1606,7 +1606,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas logger.info("Invalid prev_events for %s", event_id) continue - if room_version.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.ROOM_V1_V2: for prev_event_tuple in prev_events: if ( not isinstance(prev_event_tuple, list) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 9b997c304d..84f17a9945 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1156,7 +1156,7 @@ class EventsWorkerStore(SQLBaseStore): if format_version is None: # This means that we stored the event before we had the concept # of a event format version, so it must be a V1 event. - format_version = EventFormatVersions.V1 + format_version = EventFormatVersions.ROOM_V1_V2 room_version_id = row.room_version_id @@ -1186,10 +1186,10 @@ class EventsWorkerStore(SQLBaseStore): # # So, the following approximations should be adequate. - if format_version == EventFormatVersions.V1: + if format_version == EventFormatVersions.ROOM_V1_V2: # if it's event format v1 then it must be room v1 or v2 room_version = RoomVersions.V1 - elif format_version == EventFormatVersions.V2: + elif format_version == EventFormatVersions.ROOM_V3: # if it's event format v2 then it must be room v3 room_version = RoomVersions.V3 else: diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 46d829b062..67401272ac 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -254,7 +254,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): "room_id": self.room_id, "json": json.dumps(event_json), "internal_metadata": "{}", - "format_version": EventFormatVersions.V3, + "format_version": EventFormatVersions.ROOM_V4_PLUS, }, ) ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index d92a9ac5b7..a6679e1312 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -513,7 +513,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prev_event_format(prev_event_id: str) -> Union[Tuple[str, dict], str]: """Account for differences in prev_events format across room versions""" - if room_version.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.ROOM_V1_V2: return prev_event_id, {} return prev_event_id diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index e42d7b9ba0..f4d9fba0a1 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -821,7 +821,7 @@ def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase: def _build_auth_dict_for_room_version( room_version: RoomVersion, auth_events: Iterable[EventBase] ) -> List: - if room_version.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.ROOM_V1_V2: return [(e.event_id, "not_used") for e in auth_events] else: return [e.event_id for e in auth_events] @@ -871,7 +871,7 @@ event_count = 0 def _maybe_get_event_id_dict_for_room_version(room_version: RoomVersion) -> dict: """If this room version needs it, generate an event id""" - if room_version.event_format != EventFormatVersions.V1: + if room_version.event_format != EventFormatVersions.ROOM_V1_V2: return {} global event_count -- cgit 1.5.1 From f799eac7ea96f943ad1272a5a81f845dfa08a254 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 8 Sep 2022 17:41:48 +0200 Subject: Add timestamp to user's consent (#13741) Co-authored-by: reivilibre --- changelog.d/13741.feature | 1 + docs/admin_api/user_admin_api.md | 2 ++ synapse/handlers/admin.py | 1 + synapse/storage/databases/main/registration.py | 6 +++- .../main/delta/72/06add_consent_ts_to_users.sql | 16 +++++++++++ tests/rest/admin/test_user.py | 1 + tests/storage/test_registration.py | 33 +++++++++++++++++----- 7 files changed, 52 insertions(+), 8 deletions(-) create mode 100644 changelog.d/13741.feature create mode 100644 synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql (limited to 'tests/storage') diff --git a/changelog.d/13741.feature b/changelog.d/13741.feature new file mode 100644 index 0000000000..dff46f373f --- /dev/null +++ b/changelog.d/13741.feature @@ -0,0 +1 @@ +Document the timestamp when a user accepts the consent, if [consent tracking](https://matrix-org.github.io/synapse/latest/consent_tracking.html) is used. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c1ca0c8a64..975f05c929 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -42,6 +42,7 @@ It returns a JSON body like the following: "appservice_id": null, "consent_server_notice_sent": null, "consent_version": null, + "consent_ts": null, "external_ids": [ { "auth_provider": "", @@ -364,6 +365,7 @@ The following actions are **NOT** performed. The list may be incomplete. - Remove the user's creation (registration) timestamp - [Remove rate limit overrides](#override-ratelimiting-for-users) - Remove from monthly active users +- Remove user's consent information (consent version and timestamp) ## Reset password diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d4fe7df533..cf9f19608a 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -70,6 +70,7 @@ class AdminHandler: "appservice_id", "consent_server_notice_sent", "consent_version", + "consent_ts", "user_type", "is_guest", } diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7fb9c801da..ac821878b0 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -175,6 +175,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "is_guest", "admin", "consent_version", + "consent_ts", "consent_server_notice_sent", "appservice_id", "creation_ts", @@ -2227,7 +2228,10 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn, table="users", keyvalues={"name": user_id}, - updatevalues={"consent_version": consent_version}, + updatevalues={ + "consent_version": consent_version, + "consent_ts": self._clock.time_msec(), + }, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) diff --git a/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql b/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql new file mode 100644 index 0000000000..609eb1750f --- /dev/null +++ b/synapse/storage/schema/main/delta/72/06add_consent_ts_to_users.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE users ADD consent_ts bigint; diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 1afd082707..ec5ccf6fca 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2580,6 +2580,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("appservice_id", content) self.assertIn("consent_server_notice_sent", content) self.assertIn("consent_version", content) + self.assertIn("consent_ts", content) self.assertIn("external_ids", content) # This key was removed intentionally. Ensure it is not accidentally re-included. diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a49ac1525e..853a93afab 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase class RegistrationStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = "@my-user:test" @@ -27,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.pwhash = "{xx1}123456789" self.device_id = "akgjhdjklgshg" - def test_register(self): + def test_register(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEqual( @@ -38,6 +41,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): "admin": 0, "is_guest": 0, "consent_version": None, + "consent_ts": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 0, @@ -48,7 +52,20 @@ class RegistrationStoreTestCase(HomeserverTestCase): (self.get_success(self.store.get_user_by_id(self.user_id))), ) - def test_add_tokens(self): + def test_consent(self) -> None: + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + before_consent = self.clock.time_msec() + self.reactor.advance(5) + self.get_success(self.store.user_set_consent_version(self.user_id, "1")) + self.reactor.advance(5) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user + self.assertEqual(user["consent_version"], "1") + self.assertGreater(user["consent_ts"], before_consent) + self.assertLess(user["consent_ts"], self.clock.time_msec()) + + def test_add_tokens(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success( self.store.add_access_token_to_user( @@ -58,11 +75,12 @@ class RegistrationStoreTestCase(HomeserverTestCase): result = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) + assert result self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.device_id, self.device_id) self.assertIsNotNone(result.token_id) - def test_user_delete_access_tokens(self): + def test_user_delete_access_tokens(self) -> None: # add some tokens self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success( @@ -87,6 +105,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): # check the one not associated with the device was not deleted user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) + assert user self.assertEqual(self.user_id, user.user_id) # now delete the rest @@ -95,11 +114,11 @@ class RegistrationStoreTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertIsNone(user, "access token was not deleted without device_id") - def test_is_support_user(self): + def test_is_support_user(self) -> None: TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = self.get_success(self.store.is_support_user(None)) + res = self.get_success(self.store.is_support_user(None)) # type: ignore[arg-type] self.assertFalse(res) self.get_success( self.store.register_user(user_id=TEST_USER, password_hash=None) @@ -115,7 +134,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): res = self.get_success(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) - def test_3pid_inhibit_invalid_validation_session_error(self): + def test_3pid_inhibit_invalid_validation_session_error(self) -> None: """Tests that enabling the configuration option to inhibit 3PID errors on /requestToken also inhibits validation errors caused by an unknown session ID. """ -- cgit 1.5.1 From f2d2481e56f06005de5ae8429eca3bb31834079e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 9 Sep 2022 11:14:10 +0100 Subject: Require SQLite >= 3.27.0 (#13760) --- changelog.d/13760.removal | 1 + synapse/storage/database.py | 47 +++++----- synapse/storage/databases/main/lock.py | 121 ++++++++----------------- synapse/storage/databases/main/stats.py | 86 +++++++----------- synapse/storage/databases/main/transactions.py | 30 ++---- synapse/storage/engines/_base.py | 8 -- synapse/storage/engines/postgres.py | 7 -- synapse/storage/engines/sqlite.py | 13 +-- tests/storage/test_base.py | 1 - 9 files changed, 106 insertions(+), 208 deletions(-) create mode 100644 changelog.d/13760.removal (limited to 'tests/storage') diff --git a/changelog.d/13760.removal b/changelog.d/13760.removal new file mode 100644 index 0000000000..624e7c3678 --- /dev/null +++ b/changelog.d/13760.removal @@ -0,0 +1 @@ +Synapse will now refuse to start if configured to use SQLite < 3.27. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b394a6658b..e881bff7fb 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -533,15 +533,14 @@ class DatabasePool: if isinstance(self.engine, Sqlite3Engine): self._unsafe_to_upsert_tables.add("user_directory_search") - if self.engine.can_native_upsert: - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates of tables that aren't safe to update. - self._clock.call_later( - 0.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) def name(self) -> str: "Return the name of this database" @@ -1160,11 +1159,8 @@ class DatabasePool: attempts = 0 while True: try: - # We can autocommit if we are going to use native upserts - autocommit = ( - self.engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ) + # We can autocommit if it is safe to upsert + autocommit = table not in self._unsafe_to_upsert_tables return await self.runInteraction( desc, @@ -1199,7 +1195,7 @@ class DatabasePool: ) -> bool: """ Pick the UPSERT method which works best on the platform. Either the - native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + native one (Pg9.5+, SQLite >= 3.24), or fall back to an emulated method. Args: txn: The transaction to use. @@ -1207,14 +1203,15 @@ class DatabasePool: keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting - lock: True to lock the table when doing the upsert. + lock: True to lock the table when doing the upsert. Unused when performing + a native upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is not empty then this always returns True) """ insertion_values = insertion_values or {} - if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) @@ -1365,14 +1362,12 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. Unused if the database engine - supports native upserts. + lock: True to lock the table when doing the upsert. Unused when performing + a native upsert. """ - # We can autocommit if we are going to use native upserts - autocommit = ( - self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables - ) + # We can autocommit if it safe to upsert + autocommit = table not in self._unsafe_to_upsert_tables await self.runInteraction( desc, @@ -1406,10 +1401,10 @@ class DatabasePool: value_names: The value column names value_values: A list of each row's value column values. Ignored if value_names is empty. - lock: True to lock the table when doing the upsert. Unused if the database engine - supports native upserts. + lock: True to lock the table when doing the upsert. Unused when performing + a native upsert. """ - if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 2d7633fbd5..7270ef09da 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -129,91 +129,48 @@ class LockStore(SQLBaseStore): now = self._clock.time_msec() token = random_string(6) - if self.db_pool.engine.can_native_upsert: - - def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: - # We take out the lock if either a) there is no row for the lock - # already, b) the existing row has timed out, or c) the row is - # for this instance (which means the process got killed and - # restarted) - sql = """ - INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT (lock_name, lock_key) - DO UPDATE - SET - token = EXCLUDED.token, - instance_name = EXCLUDED.instance_name, - last_renewed_ts = EXCLUDED.last_renewed_ts - WHERE - worker_locks.last_renewed_ts < ? - OR worker_locks.instance_name = EXCLUDED.instance_name - """ - txn.execute( - sql, - ( - lock_name, - lock_key, - self._instance_name, - token, - now, - now - _LOCK_TIMEOUT_MS, - ), - ) - - # We only acquired the lock if we inserted or updated the table. - return bool(txn.rowcount) - - did_lock = await self.db_pool.runInteraction( - "try_acquire_lock", - _try_acquire_lock_txn, - # We can autocommit here as we're executing a single query, this - # will avoid serialization errors. - db_autocommit=True, + def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: + # We take out the lock if either a) there is no row for the lock + # already, b) the existing row has timed out, or c) the row is + # for this instance (which means the process got killed and + # restarted) + sql = """ + INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (lock_name, lock_key) + DO UPDATE + SET + token = EXCLUDED.token, + instance_name = EXCLUDED.instance_name, + last_renewed_ts = EXCLUDED.last_renewed_ts + WHERE + worker_locks.last_renewed_ts < ? + OR worker_locks.instance_name = EXCLUDED.instance_name + """ + txn.execute( + sql, + ( + lock_name, + lock_key, + self._instance_name, + token, + now, + now - _LOCK_TIMEOUT_MS, + ), ) - if not did_lock: - return None - - else: - # If we're on an old SQLite we emulate the above logic by first - # clearing out any existing stale locks and then upserting. - - def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool: - sql = """ - DELETE FROM worker_locks - WHERE - lock_name = ? - AND lock_key = ? - AND (last_renewed_ts < ? OR instance_name = ?) - """ - txn.execute( - sql, - (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name), - ) - - inserted = self.db_pool.simple_upsert_txn_emulated( - txn, - table="worker_locks", - keyvalues={ - "lock_name": lock_name, - "lock_key": lock_key, - }, - values={}, - insertion_values={ - "token": token, - "last_renewed_ts": self._clock.time_msec(), - "instance_name": self._instance_name, - }, - ) - - return inserted - did_lock = await self.db_pool.runInteraction( - "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn - ) + # We only acquired the lock if we inserted or updated the table. + return bool(txn.rowcount) - if not did_lock: - return None + did_lock = await self.db_pool.runInteraction( + "try_acquire_lock", + _try_acquire_lock_txn, + # We can autocommit here as we're executing a single query, this + # will avoid serialization errors. + db_autocommit=True, + ) + if not did_lock: + return None lock = Lock( self._reactor, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index b4c652acf3..356d4ca788 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -446,59 +446,41 @@ class StatsStore(StateDeltasStore): absolutes: Absolute (set) fields additive_relatives: Fields that will be added onto if existing row present. """ - if self.database_engine.can_native_upsert: - absolute_updates = [ - "%(field)s = EXCLUDED.%(field)s" % {"field": field} - for field in absolutes.keys() - ] - - relative_updates = [ - "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)" - % {"table": table, "field": field} - for field in additive_relatives.keys() - ] - - insert_cols = [] - qargs = [] - - for (key, val) in chain( - keyvalues.items(), absolutes.items(), additive_relatives.items() - ): - insert_cols.append(key) - qargs.append(val) + absolute_updates = [ + "%(field)s = EXCLUDED.%(field)s" % {"field": field} + for field in absolutes.keys() + ] + + relative_updates = [ + "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)" + % {"table": table, "field": field} + for field in additive_relatives.keys() + ] + + insert_cols = [] + qargs = [] + + for (key, val) in chain( + keyvalues.items(), absolutes.items(), additive_relatives.items() + ): + insert_cols.append(key) + qargs.append(val) + + sql = """ + INSERT INTO %(table)s (%(insert_cols_cs)s) + VALUES (%(insert_vals_qs)s) + ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s + """ % { + "table": table, + "insert_cols_cs": ", ".join(insert_cols), + "insert_vals_qs": ", ".join( + ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) + ), + "key_columns": ", ".join(keyvalues), + "updates": ", ".join(chain(absolute_updates, relative_updates)), + } - sql = """ - INSERT INTO %(table)s (%(insert_cols_cs)s) - VALUES (%(insert_vals_qs)s) - ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s - """ % { - "table": table, - "insert_cols_cs": ", ".join(insert_cols), - "insert_vals_qs": ", ".join( - ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives)) - ), - "key_columns": ", ".join(keyvalues), - "updates": ", ".join(chain(absolute_updates, relative_updates)), - } - - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, table) - retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self.db_pool.simple_select_one_txn( - txn, table, keyvalues, retcols, allow_none=True - ) - if current_row is None: - merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self.db_pool.simple_insert_txn(txn, table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - if current_row[key] is None: - current_row[key] = val - else: - current_row[key] += val - current_row.update(absolutes) - self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) + txn.execute(sql, qargs) async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None: """Calculate and insert an entry into room_stats_current. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index ba79e19f7f..f8c6877ee8 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -221,25 +221,15 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): retry_interval: how long until next retry in ms """ - if self.database_engine.can_native_upsert: - await self.db_pool.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings_native, - destination, - failure_ts, - retry_last_ts, - retry_interval, - db_autocommit=True, # Safe as its a single upsert - ) - else: - await self.db_pool.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings_emulated, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) + await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_native, + destination, + failure_ts, + retry_last_ts, + retry_interval, + db_autocommit=True, # Safe as it's a single upsert + ) def _set_destination_retry_timings_native( self, @@ -249,8 +239,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): retry_last_ts: int, retry_interval: int, ) -> None: - assert self.database_engine.can_native_upsert - # Upsert retry time interval if retry_interval is zero (i.e. we're # resetting it) or greater than the existing retry interval. # diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 971ff82693..0d16a419a4 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -43,14 +43,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): def single_threaded(self) -> bool: ... - @property - @abc.abstractmethod - def can_native_upsert(self) -> bool: - """ - Do we support native UPSERTs? - """ - ... - @property @abc.abstractmethod def supports_using_any_list(self) -> bool: diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 517f9d5f98..7f7d006ac2 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -158,13 +158,6 @@ class PostgresEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): cursor.close() db_conn.commit() - @property - def can_native_upsert(self) -> bool: - """ - Can we use native UPSERTs? - """ - return True - @property def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 621f2c5efe..095ae0a096 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -48,14 +48,6 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): def single_threaded(self) -> bool: return True - @property - def can_native_upsert(self) -> bool: - """ - Do we support native UPSERTs? This requires SQLite3 3.24+, plus some - more work we haven't done yet to tell what was inserted vs updated. - """ - return sqlite3.sqlite_version_info >= (3, 24, 0) - @property def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" @@ -70,12 +62,11 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False ) -> None: if not allow_outdated_version: - version = sqlite3.sqlite_version_info # Synapse is untested against older SQLite versions, and we don't want # to let users upgrade to a version of Synapse with broken support for their # sqlite version, because it risks leaving them with a half-upgraded db. - if version < (3, 22, 0): - raise RuntimeError("Synapse requires sqlite 3.22 or above.") + if sqlite3.sqlite_version_info < (3, 27, 0): + raise RuntimeError("Synapse requires sqlite 3.27 or above.") def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cce8e75c74..40e58f8199 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -54,7 +54,6 @@ class SQLBaseStoreTestCase(unittest.TestCase): sqlite_config = {"name": "sqlite3"} engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) - fake_engine.can_native_upsert = False fake_engine.in_transaction.return_value = False db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) -- cgit 1.5.1 From efd108b45d1706526416bc9a6f89463b5ff4506a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 23 Sep 2022 10:33:28 -0400 Subject: Accept & store thread IDs for receipts (implement MSC3771). (#13782) Updates the `/receipts` endpoint and receipt EDU handler to parse a `thread_id` from the body and insert it in the database. --- changelog.d/13782.feature | 1 + synapse/config/experimental.py | 2 + synapse/handlers/receipts.py | 23 ++++++- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/streams/_base.py | 1 + synapse/rest/client/read_marker.py | 2 + synapse/rest/client/receipts.py | 14 ++++- synapse/rest/client/versions.py | 2 + synapse/storage/database.py | 2 + synapse/storage/databases/main/receipts.py | 87 +++++++++++++++++++------- synapse/types.py | 1 + tests/federation/test_federation_sender.py | 21 ++++++- tests/handlers/test_appservice.py | 1 + tests/replication/slave/storage/test_events.py | 2 +- tests/replication/tcp/streams/test_receipts.py | 15 ++++- tests/storage/test_event_push_actions.py | 1 + tests/storage/test_receipts.py | 36 ++++++++--- 17 files changed, 173 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13782.feature (limited to 'tests/storage') diff --git a/changelog.d/13782.feature b/changelog.d/13782.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13782.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 595eb007a5..933779c23a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -83,6 +83,8 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + # MSC3771: Thread read receipts + self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index afaf3261df..4768a34c07 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,6 +63,8 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() + self._msc3771_enabled = hs.config.experimental.msc3771_enabled + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -91,13 +93,23 @@ class ReceiptsHandler: ) continue + # Check if these receipts apply to a thread. + thread_id = None + data = user_values.get("data", {}) + if self._msc3771_enabled and isinstance(data, dict): + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None + receipts.append( ReadReceipt( room_id=room_id, receipt_type=receipt_type, user_id=user_id, event_ids=user_values["event_ids"], - data=user_values.get("data", {}), + thread_id=thread_id, + data=data, ) ) @@ -114,6 +126,7 @@ class ReceiptsHandler: receipt.receipt_type, receipt.user_id, receipt.event_ids, + receipt.thread_id, receipt.data, ) @@ -146,7 +159,12 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -156,6 +174,7 @@ class ReceiptsHandler: receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], + thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cf9cd6833b..b2522f98ca 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -427,7 +427,8 @@ class FederationSenderHandler: receipt.receipt_type, receipt.user_id, [receipt.event_id], - receipt.data, + thread_id=receipt.thread_id, + data=receipt.data, ) await self.federation_sender.send_read_receipt(receipt_info) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 398bebeaa6..e01155ad59 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -361,6 +361,7 @@ class ReceiptsStream(Stream): receipt_type: str user_id: str event_id: str + thread_id: Optional[str] data: dict NAME = "receipts" diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 5e53096539..852838515c 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -83,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + # Setting the thread ID is not possible with the /read_markers endpoint. + thread_id=None, ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 5b7fad7402..f3ff156abe 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,6 +49,7 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } + self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -61,7 +62,17 @@ class ReceiptRestServlet(RestServlet): f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - parse_json_object_from_request(request, allow_empty_body=False) + body = parse_json_object_from_request(request) + + # Pull the thread ID, if one exists. + thread_id = None + if self._msc3771_enabled: + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, "thread_id field must be a non-empty string" + ) await self.presence_handler.bump_presence_active_time(requester.user) @@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + thread_id=thread_id, ) return 200, {} diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index b3917a5abc..c95b0d6f19 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,6 +103,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above + # Support for thread read receipts. + "org.matrix.msc3771": self.config.experimental.msc3771_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 921cd4dc5e..9d116f6925 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -95,6 +95,8 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", "event_push_summary": "event_push_summary_unique_index", + "receipts_linearized": "receipts_linearized_unique_index", + "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ddb8e80b69..52fe0db924 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool + ]: """Get updates for receipts replication stream. Args: @@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_all_updated_receipts_txn( txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + int, + bool, + ]: sql = """ - SELECT stream_id, room_id, receipt_type, user_id, event_id, data + SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) updates = cast( - List[Tuple[int, list]], - [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn], ) limited = False @@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_id: str, + thread_id: Optional[str], data: JsonDict, stream_id: int, ) -> Optional[int]: @@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore): # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + if thread_id is None: + thread_clause = "r.thread_id IS NULL" + thread_args: Tuple[str, ...] = () + else: + thread_clause = "r.thread_id = ?" + thread_args = (thread_id,) + + sql = f""" + SELECT stream_ordering, event_id FROM events + INNER JOIN receipts_linearized AS r USING (event_id, room_id) + WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause} + """ + txn.execute( + sql, + ( + room_id, + receipt_type, + user_id, + ) + + thread_args, ) - txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: @@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "stream_id": stream_id, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, @@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: dict, ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. @@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, linearized_event_id, + thread_id, data, stream_id=stream_id, # Read committed is actually beneficial here because we check for a receipt with @@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore): now = self._clock.time_msec() logger.debug( - "RR for event %s in %s (%i ms old)", + "Receipt %s for event %s in %s (%i ms old)", + receipt_type, linearized_event_id, room_id, now - event_ts, @@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, event_ids, + thread_id, data, ) @@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: JsonDict, ) -> None: assert self._can_write_to_receipts @@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, diff --git a/synapse/types.py b/synapse/types.py index ec44601f54..773f0438d5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -835,6 +835,7 @@ class ReadReceipt: receipt_type: str user_id: str event_ids: List[str] + thread_id: Optional[str] data: JsonDict diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5aa500ef8..f1e357764f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): # send the second RR receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["other_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b17af2725b..af24c4984d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipt_type="m.read", user_id=self.local_user, event_ids=[f"$eventid_{i}"], + thread_id=None, data={}, ) ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 49a21e2e85..efd92793c0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -171,7 +171,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if send_receipt: self.get_success( self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} ) ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index eb00117845..ede6d0c118 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} + "!room:blue", + "m.read", + USER_ID, + ["$event:blue"], + thread_id=None, + data={"a": 1}, ) ) self.replicate() @@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) + self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) # Now let's disconnect and insert some data. @@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + "!room2:blue", + "m.read", + USER_ID, + ["$event2:foo"], + thread_id=None, + data={"a": 2}, ) ) self.replicate() diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index fc43d7edd1..08c74b93e3 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -106,6 +106,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "m.read", user_id=user_id, event_ids=[event_id], + thread_id=None, data={}, ) ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index c89bfff241..9459ee1705 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -131,13 +131,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -164,7 +169,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -180,7 +185,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( @@ -202,13 +212,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -241,7 +256,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -259,7 +274,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( -- cgit 1.5.1 From ac1a31740b6d0dfda4d57a25762aaddfde981caf Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 23 Sep 2022 14:01:29 -0500 Subject: Only try to backfill event if we haven't tried before recently (#13635) Only try to backfill event if we haven't tried before recently (exponential backoff). No need to keep trying the same backfill point that fails over and over. Fix https://github.com/matrix-org/synapse/issues/13622 Fix https://github.com/matrix-org/synapse/issues/8451 Follow-up to https://github.com/matrix-org/synapse/pull/13589 Part of https://github.com/matrix-org/synapse/issues/13356 --- changelog.d/13635.feature | 1 + synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/event_federation.py | 188 ++++++-- tests/storage/test_event_federation.py | 481 ++++++++++++++++++++- 4 files changed, 626 insertions(+), 48 deletions(-) create mode 100644 changelog.d/13635.feature (limited to 'tests/storage') diff --git a/changelog.d/13635.feature b/changelog.d/13635.feature new file mode 100644 index 0000000000..d86bf7ed80 --- /dev/null +++ b/changelog.d/13635.feature @@ -0,0 +1 @@ +Exponentially backoff from backfilling the same event over and over. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 583d5ecd77..e1a4265a64 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -226,9 +226,7 @@ class FederationHandler: """ backwards_extremities = [ _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) - for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room( - room_id - ) + for event_id, depth in await self.store.get_backfill_points_in_room(room_id) ] insertion_events_to_be_backfilled: List[_BackfillPoint] = [] diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ef477978ed..3251fca6fb 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime import itertools import logging from queue import Empty, PriorityQueue @@ -43,7 +44,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -72,6 +73,13 @@ pdus_pruned_from_federation_queue = Counter( logger = logging.getLogger(__name__) +BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int( + datetime.timedelta(days=7).total_seconds() +) +BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int( + datetime.timedelta(hours=1).total_seconds() +) + # All the info we need while iterating the DAG while backfilling @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -715,96 +723,189 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @trace @tag_args - async def get_oldest_event_ids_with_depth_in_room( - self, room_id: str + async def get_backfill_points_in_room( + self, + room_id: str, ) -> List[Tuple[str, int]]: - """Gets the oldest events(backwards extremities) in the room along with the - aproximate depth. - - We use this function so that we can compare and see if someones current - depth at their current scrollback is within pagination range of the - event extremeties. If the current depth is close to the depth of given - oldest event, we can trigger a backfill. + """ + Gets the oldest events(backwards extremities) in the room along with the + approximate depth. Sorted by depth, highest to lowest (descending). Args: room_id: Room where we want to find the oldest events Returns: - List of (event_id, depth) tuples + List of (event_id, depth) tuples. Sorted by depth, highest to lowest + (descending) """ - def get_oldest_event_ids_with_depth_in_room_txn( + def get_backfill_points_in_room_txn( txn: LoggingTransaction, room_id: str ) -> List[Tuple[str, int]]: - # Assemble a dictionary with event_id -> depth for the oldest events + # Assemble a tuple lookup of event_id -> depth for the oldest events # we know of in the room. Backwards extremeties are the oldest # events we know of in the room but we only know of them because - # some other event referenced them by prev_event and aren't peristed - # in our database yet (meaning we don't know their depth - # specifically). So we need to look for the aproximate depth from + # some other event referenced them by prev_event and aren't + # persisted in our database yet (meaning we don't know their depth + # specifically). So we need to look for the approximate depth from # the events connected to the current backwards extremeties. sql = """ - SELECT b.event_id, MAX(e.depth) FROM events as e + SELECT backward_extrem.event_id, event.depth FROM events AS event /** * Get the edge connections from the event_edges table * so we can see whether this event's prev_events points * to a backward extremity in the next join. */ - INNER JOIN event_edges as g - ON g.event_id = e.event_id + INNER JOIN event_edges AS edge + ON edge.event_id = event.event_id /** * We find the "oldest" events in the room by looking for * events connected to backwards extremeties (oldest events * in the room that we know of so far). */ - INNER JOIN event_backward_extremities as b - ON g.prev_event_id = b.event_id - WHERE b.room_id = ? AND g.is_state is ? - GROUP BY b.event_id + INNER JOIN event_backward_extremities AS backward_extrem + ON edge.prev_event_id = backward_extrem.event_id + /** + * We use this info to make sure we don't retry to use a backfill point + * if we've already attempted to backfill from it recently. + */ + LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info + ON + failed_backfill_attempt_info.room_id = backward_extrem.room_id + AND failed_backfill_attempt_info.event_id = backward_extrem.event_id + WHERE + backward_extrem.room_id = ? + /* We only care about non-state edges because we used to use + * `event_edges` for two different sorts of "edges" (the current + * event DAG, but also a link to the previous state, for state + * events). These legacy state event edges can be distinguished by + * `is_state` and are removed from the codebase and schema but + * because the schema change is in a background update, it's not + * necessarily safe to assume that it will have been completed. + */ + AND edge.is_state is ? /* False */ + /** + * Exponential back-off (up to the upper bound) so we don't retry the + * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + * + * We use `1 << n` as a power of 2 equivalent for compatibility + * with older SQLites. The left shift equivalent only works with + * powers of 2 because left shift is a binary operation (base-2). + * Otherwise, we would use `power(2, n)` or the power operator, `2^n`. + */ + AND ( + failed_backfill_attempt_info.event_id IS NULL + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + ) + /** + * Sort from highest to the lowest depth. Then tie-break on + * alphabetical order of the event_ids so we get a consistent + * ordering which is nice when asserting things in tests. + */ + ORDER BY event.depth DESC, backward_extrem.event_id DESC """ - txn.execute(sql, (room_id, False)) + if isinstance(self.database_engine, PostgresEngine): + least_function = "least" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "min" + else: + raise RuntimeError("Unknown database engine") + + txn.execute( + sql % (least_function,), + ( + room_id, + False, + self._clock.time_msec(), + 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, + 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + ), + ) return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( - "get_oldest_event_ids_with_depth_in_room", - get_oldest_event_ids_with_depth_in_room_txn, + "get_backfill_points_in_room", + get_backfill_points_in_room_txn, room_id, ) @trace async def get_insertion_event_backward_extremities_in_room( - self, room_id: str + self, + room_id: str, ) -> List[Tuple[str, int]]: - """Get the insertion events we know about that we haven't backfilled yet. - - We use this function so that we can compare and see if someones current - depth at their current scrollback is within pagination range of the - insertion event. If the current depth is close to the depth of given - insertion event, we can trigger a backfill. + """ + Get the insertion events we know about that we haven't backfilled yet + along with the approximate depth. Sorted by depth, highest to lowest + (descending). Args: room_id: Room where we want to find the oldest events Returns: - List of (event_id, depth) tuples + List of (event_id, depth) tuples. Sorted by depth, highest to lowest + (descending) """ def get_insertion_event_backward_extremities_in_room_txn( txn: LoggingTransaction, room_id: str ) -> List[Tuple[str, int]]: sql = """ - SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + SELECT + insertion_event_extremity.event_id, event.depth /* We only want insertion events that are also marked as backwards extremities */ - INNER JOIN insertion_event_extremities as b USING (event_id) + FROM insertion_event_extremities AS insertion_event_extremity /* Get the depth of the insertion event from the events table */ - INNER JOIN events AS e USING (event_id) - WHERE b.room_id = ? - GROUP BY b.event_id + INNER JOIN events AS event USING (event_id) + /** + * We use this info to make sure we don't retry to use a backfill point + * if we've already attempted to backfill from it recently. + */ + LEFT JOIN event_failed_pull_attempts AS failed_backfill_attempt_info + ON + failed_backfill_attempt_info.room_id = insertion_event_extremity.room_id + AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id + WHERE + insertion_event_extremity.room_id = ? + /** + * Exponential back-off (up to the upper bound) so we don't retry the + * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc + * + * We use `1 << n` as a power of 2 equivalent for compatibility + * with older SQLites. The left shift equivalent only works with + * powers of 2 because left shift is a binary operation (base-2). + * Otherwise, we would use `power(2, n)` or the power operator, `2^n`. + */ + AND ( + failed_backfill_attempt_info.event_id IS NULL + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + ) + /** + * Sort from highest to the lowest depth. Then tie-break on + * alphabetical order of the event_ids so we get a consistent + * ordering which is nice when asserting things in tests. + */ + ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC """ - txn.execute(sql, (room_id,)) + if isinstance(self.database_engine, PostgresEngine): + least_function = "least" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "min" + else: + raise RuntimeError("Unknown database engine") + + txn.execute( + sql % (least_function,), + ( + room_id, + self._clock.time_msec(), + 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, + 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + ), + ) return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( @@ -1539,7 +1640,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas self, room_id: str, ) -> Optional[Tuple[str, str]]: - """Get the next event ID in the staging area for the given room.""" + """ + Get the next event ID in the staging area for the given room. + + Returns: + Tuple of the `origin` and `event_id` + """ def _get_next_staged_event_id_for_room_txn( txn: LoggingTransaction, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a6679e1312..85739c464e 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -12,25 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +import datetime +from typing import Dict, List, Tuple, Union import attr from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersion, ) from synapse.events import _EventInternalMetadata -from synapse.util import json_encoder +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.types import JsonDict +from synapse.util import Clock, json_encoder import tests.unittest import tests.utils +@attr.s(auto_attribs=True, frozen=True, slots=True) +class _BackfillSetupInfo: + room_id: str + depth_map: Dict[str, int] + + class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main def test_get_prev_events_for_room(self): @@ -571,11 +584,471 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertEqual(count, 1) - _, event_id = self.get_success( + next_staged_event_info = self.get_success( self.store.get_next_staged_event_id_for_room(room_id) ) + assert next_staged_event_info + _, event_id = next_staged_event_info self.assertEqual(event_id, "$fake_event_id_500") + def _setup_room_for_backfill_tests(self) -> _BackfillSetupInfo: + """ + Sets up a room with various events and backward extremities to test + backfill functions against. + + Returns: + _BackfillSetupInfo including the `room_id` to test against and + `depth_map` of events in the room + """ + room_id = "!backfill-room-test:some-host" + + # The silly graph we use to test grabbing backward extremities, + # where the top is the oldest events. + # 1 (oldest) + # | + # 2 ⹁ + # | \ + # | [b1, b2, b3] + # | | + # | A + # | / + # 3 { + # | \ + # | [b4, b5, b6] + # | | + # | B + # | / + # 4 ´ + # | + # 5 (newest) + + event_graph: Dict[str, List[str]] = { + "1": [], + "2": ["1"], + "3": ["2", "A"], + "4": ["3", "B"], + "5": ["4"], + "A": ["b1", "b2", "b3"], + "b1": ["2"], + "b2": ["2"], + "b3": ["2"], + "B": ["b4", "b5", "b6"], + "b4": ["3"], + "b5": ["3"], + "b6": ["3"], + } + + depth_map: Dict[str, int] = { + "1": 1, + "2": 2, + "b1": 3, + "b2": 3, + "b3": 3, + "A": 4, + "3": 5, + "b4": 6, + "b5": 6, + "b6": 6, + "B": 7, + "4": 8, + "5": 9, + } + + # The events we have persisted on our server. + # The rest are events in the room but not backfilled tet. + our_server_events = {"5", "4", "B", "3", "A"} + + complete_event_dict_map: Dict[str, JsonDict] = {} + stream_ordering = 0 + for (event_id, prev_event_ids) in event_graph.items(): + depth = depth_map[event_id] + + complete_event_dict_map[event_id] = { + "event_id": event_id, + "type": "test_regular_type", + "room_id": room_id, + "sender": "@sender", + "prev_event_ids": prev_event_ids, + "auth_event_ids": [], + "origin_server_ts": stream_ordering, + "depth": depth, + "stream_ordering": stream_ordering, + "content": {"body": "event" + event_id}, + } + + stream_ordering += 1 + + def populate_db(txn: LoggingTransaction): + # Insert the room to satisfy the foreign key constraint of + # `event_failed_pull_attempts` + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + }, + ) + + # Insert our server events + for event_id in our_server_events: + event_dict = complete_event_dict_map[event_id] + + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_dict.get("event_id"), + "type": event_dict.get("type"), + "room_id": event_dict.get("room_id"), + "depth": event_dict.get("depth"), + "topological_ordering": event_dict.get("depth"), + "stream_ordering": event_dict.get("stream_ordering"), + "processed": True, + "outlier": False, + }, + ) + + # Insert the event edges + for event_id in our_server_events: + for prev_event_id in event_graph[event_id]: + self.store.db_pool.simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event_id, + "prev_event_id": prev_event_id, + "room_id": room_id, + }, + ) + + # Insert the backward extremities + prev_events_of_our_events = { + prev_event_id + for our_server_event in our_server_events + for prev_event_id in complete_event_dict_map[our_server_event][ + "prev_event_ids" + ] + } + backward_extremities = prev_events_of_our_events - our_server_events + for backward_extremity in backward_extremities: + self.store.db_pool.simple_insert_txn( + txn, + table="event_backward_extremities", + values={ + "event_id": backward_extremity, + "room_id": room_id, + }, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "_setup_room_for_backfill_tests_populate_db", + populate_db, + ) + ) + + return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) + + def test_get_backfill_points_in_room(self): + """ + Test to make sure we get some backfill points + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] + ) + + def test_get_backfill_points_in_room_excludes_events_we_have_attempted( + self, + ): + """ + Test to make sure that events we have attempted to backfill (and within + backoff timeout duration) do not show up as an event to backfill again. + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_backfill_points_in_room` exclude them because we + # haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b5", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b4", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b2", "fake cause") + ) + + # No time has passed since we attempted to backfill ^ + + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Only the backfill points that we didn't record earlier exist here. + self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"]) + + def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( + self, + ): + """ + Test to make sure after we fake attempt to backfill event "b3" many times, + we can see retry and see the "b3" again after the backoff timeout duration + has exceeded. + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_backfill_points_in_room` exclude them because we + # haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + + # Now advance time by 2 hours and we should only be able to see "b3" + # because we have waited long enough for the single attempt (2^1 hours) + # but we still shouldn't see "b1" because we haven't waited long enough + # for this many attempts. We didn't do anything to "b2" so it should be + # visible regardless. + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + # Make sure that "b1" is not in the list because we've + # already attempted many times + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"]) + + # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and + # see if we can now backfill it + self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) + + # Try again after we advanced enough time and we should see "b3" again + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] + ) + + def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: + """ + Sets up a room with various insertion event backward extremities to test + backfill functions against. + + Returns: + _BackfillSetupInfo including the `room_id` to test against and + `depth_map` of events in the room + """ + room_id = "!backfill-room-test:some-host" + + depth_map: Dict[str, int] = { + "1": 1, + "2": 2, + "insertion_eventA": 3, + "3": 4, + "insertion_eventB": 5, + "4": 6, + "5": 7, + } + + def populate_db(txn: LoggingTransaction): + # Insert the room to satisfy the foreign key constraint of + # `event_failed_pull_attempts` + self.store.db_pool.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": "room_creator_user_id", + "is_public": True, + "room_version": "6", + }, + ) + + # Insert our server events + stream_ordering = 0 + for event_id, depth in depth_map.items(): + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "type": EventTypes.MSC2716_INSERTION + if event_id.startswith("insertion_event") + else "test_regular_type", + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "stream_ordering": stream_ordering, + "processed": True, + "outlier": False, + }, + ) + + if event_id.startswith("insertion_event"): + self.store.db_pool.simple_insert_txn( + txn, + table="insertion_event_extremities", + values={ + "event_id": event_id, + "room_id": room_id, + }, + ) + + stream_ordering += 1 + + self.get_success( + self.store.db_pool.runInteraction( + "_setup_room_for_insertion_backfill_tests_populate_db", + populate_db, + ) + ) + + return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) + + def test_get_insertion_event_backward_extremities_in_room(self): + """ + Test to make sure insertion event backward extremities are returned. + """ + setup_info = self._setup_room_for_insertion_backfill_tests() + room_id = setup_info.room_id + + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["insertion_eventB", "insertion_eventA"] + ) + + def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( + self, + ): + """ + Test to make sure that insertion events we have attempted to backfill + (and within backoff timeout duration) do not show up as an event to + backfill again. + """ + setup_info = self._setup_room_for_insertion_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_insertion_event_backward_extremities_in_room` exclude them + # because we haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + + # No time has passed since we attempted to backfill ^ + + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Only the backfill points that we didn't record earlier exist here. + self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + + def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( + self, + ): + """ + Test to make sure after we fake attempt to backfill event + "insertion_eventA" many times, we can see retry and see the + "insertion_eventA" again after the backoff timeout duration has + exceeded. + """ + setup_info = self._setup_room_for_insertion_backfill_tests() + room_id = setup_info.room_id + + # Record some attempts to backfill these events which will make + # `get_backfill_points_in_room` exclude them because we + # haven't passed the backoff interval. + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventB", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "insertion_eventA", "fake cause" + ) + ) + + # Now advance time by 2 hours and we should only be able to see + # "insertion_eventB" because we have waited long enough for the single + # attempt (2^1 hours) but we still shouldn't see "insertion_eventA" + # because we haven't waited long enough for this many attempts. + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + # Make sure that "insertion_eventA" is not in the list because we've + # already attempted many times + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + + # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and + # see if we can now backfill it + self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) + + # Try at "insertion_eventA" again after we advanced enough time and we + # should see "insertion_eventA" again + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room(room_id) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertListEqual( + backfill_event_ids, ["insertion_eventB", "insertion_eventA"] + ) + @attr.s class FakeEvent: -- cgit 1.5.1 From 2fae1a3f7862bf38cd0b52dfd3ea3ae76794d2b7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 26 Sep 2022 14:28:12 -0400 Subject: Improve tests for get_unread_push_actions_for_user_in_range_*. (#13893) * Adds a docstring. * Reduces a small amount of duplicated code. * Improves tests. --- changelog.d/13893.feature | 1 + .../storage/databases/main/event_push_actions.py | 38 ++++++---- tests/storage/test_event_push_actions.py | 88 ++++++++++++++++++---- 3 files changed, 97 insertions(+), 30 deletions(-) create mode 100644 changelog.d/13893.feature (limited to 'tests/storage') diff --git a/changelog.d/13893.feature b/changelog.d/13893.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13893.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 6b8668d2dc..f4cdc2e399 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -559,7 +559,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> List[Tuple[str, int]]: + ) -> Dict[str, int]: + """ + Generate a map of room ID to the latest stream ordering that has been + read by the given user. + + Args: + txn: + user_id: The user to fetch receipts for. + + Returns: + A map of room ID to stream ordering for all rooms the user has a receipt in. + """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", @@ -580,7 +591,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas args.extend((user_id,)) txn.execute(sql, args) - return cast(List[Tuple[str, int]], txn.fetchall()) + return { + room_id: latest_stream_ordering + for room_id, latest_stream_ordering in txn.fetchall() + } async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -605,12 +619,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn( @@ -679,12 +691,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 08c74b93e3..473c965e19 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin @@ -22,8 +24,6 @@ from synapse.util import Clock from tests.unittest import HomeserverTestCase -USER_ID = "@user:example.com" - class EventPushActionsStoreTestCase(HomeserverTestCase): servlets = [ @@ -38,21 +38,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): assert persist_events_store is not None self.persist_events_store = persist_events_store - def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None: - self.get_success( - self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 - ) - ) + def _create_users_and_room(self) -> Tuple[str, str, str, str, str]: + """ + Creates two users and a shared room. - def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None: - self.get_success( - self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 - ) - ) - - def test_count_aggregation(self) -> None: + Returns: + Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID). + """ # Create a user to receive notifications and send receipts. user_id = self.register_user("user1235", "pass") token = self.login("user1235", "pass") @@ -65,6 +57,70 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): room_id = self.helper.create_room_as(user_id, tok=token) self.helper.join(room_id, other_id, tok=other_token) + return user_id, token, other_id, other_token, room_id + + def test_get_unread_push_actions_for_user_in_range(self) -> None: + """Test getting unread push actions for HTTP and email pushers.""" + user_id, token, _, other_token, room_id = self._create_users_and_room() + + # Create two events, one of which is a highlight. + self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": "msg"}, + tok=other_token, + ) + event_id = self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": user_id}, + tok=other_token, + )["event_id"] + + # Fetch unread actions for HTTP pushers. + http_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_http( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual(2, len(http_actions)) + + # Fetch unread actions for email pushers. + email_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_email( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual(2, len(email_actions)) + + # Send a receipt, which should clear any actions. + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=None, + data={}, + ) + ) + http_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_http( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual([], http_actions) + email_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_email( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual([], email_actions) + + def test_count_aggregation(self) -> None: + # Create a user to receive notifications and send receipts. + user_id, token, _, other_token, room_id = self._create_users_and_room() + last_event_id: str def _assert_counts(noitf_count: int, highlight_count: int) -> None: -- cgit 1.5.1 From 29269d9d3f3419a3d92cdd80dae4a37e2d99a395 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 27 Sep 2022 15:55:43 -0500 Subject: Fix `have_seen_event` cache not being invalidated (#13863) Fix https://github.com/matrix-org/synapse/issues/13856 Fix https://github.com/matrix-org/synapse/issues/13865 > Discovered while trying to make Synapse fast enough for [this MSC2716 test for importing many batches](https://github.com/matrix-org/complement/pull/214#discussion_r741678240). As an example, disabling the `have_seen_event` cache saves 10 seconds for each `/messages` request in that MSC2716 Complement test because we're not making as many federation requests for `/state` (speeding up `have_seen_event` itself is related to https://github.com/matrix-org/synapse/issues/13625) > > But this will also make `/messages` faster in general so we can include it in the [faster `/messages` milestone](https://github.com/matrix-org/synapse/milestone/11). > > *-- https://github.com/matrix-org/synapse/issues/13856* ### The problem `_invalidate_caches_for_event` doesn't run in monolith mode which means we never even tried to clear the `have_seen_event` and other caches. And even in worker mode, it only runs on the workers, not the master (AFAICT). Additionally there was bug with the key being wrong so `_invalidate_caches_for_event` never invalidates the `have_seen_event` cache even when it does run. Because we were using the `@cachedList` wrong, it was putting items in the cache under keys like `((room_id, event_id),)` with a `set` in a `set` (ex. `(('!TnCIJPKzdQdUlIyXdQ:test', '$Iu0eqEBN7qcyF1S9B3oNB3I91v2o5YOgRNPwi_78s-k'),)`) and we we're trying to invalidate with just `(room_id, event_id)` which did nothing. --- changelog.d/13863.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 40 +++--- synapse/util/caches/descriptors.py | 6 + tests/storage/databases/main/test_events_worker.py | 152 ++++++++++++++------- tests/util/caches/test_descriptors.py | 33 ++++- 5 files changed, 165 insertions(+), 67 deletions(-) create mode 100644 changelog.d/13863.bugfix (limited to 'tests/storage') diff --git a/changelog.d/13863.bugfix b/changelog.d/13863.bugfix new file mode 100644 index 0000000000..74264a4fab --- /dev/null +++ b/changelog.d/13863.bugfix @@ -0,0 +1 @@ +Fix `have_seen_event` cache not being invalidated after we persist an event which causes inefficiency effects like extra `/state` federation calls. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 52914febf9..7cdc9fe98f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore): # the batches as big as possible. results: Set[str] = set() - for chunk in batch_iter(event_ids, 500): - r = await self._have_seen_events_dict( - [(room_id, event_id) for event_id in chunk] + for event_ids_chunk in batch_iter(event_ids, 500): + events_seen_dict = await self._have_seen_events_dict( + room_id, event_ids_chunk + ) + results.update( + eid for (eid, have_event) in events_seen_dict.items() if have_event ) - results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) return results - @cachedList(cached_method_name="have_seen_event", list_name="keys") + @cachedList(cached_method_name="have_seen_event", list_name="event_ids") async def _have_seen_events_dict( - self, keys: Collection[Tuple[str, str]] - ) -> Dict[Tuple[str, str], bool]: + self, + room_id: str, + event_ids: Collection[str], + ) -> Dict[str, bool]: """Helper for have_seen_events Returns: - a dict {(room_id, event_id)-> bool} + a dict {event_id -> bool} """ # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) - for (rid, eid) in keys - if await self._get_event_cache.contains((eid,)) + event_id + for event_id in event_ids + if await self._get_event_cache.contains((event_id,)) } results = dict.fromkeys(cache_results, True) - remaining = [k for k in keys if k not in cache_results] + remaining = [ + event_id for event_id in event_ids if event_id not in cache_results + ] if not remaining: return results @@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] + txn.database_engine, "e.event_id", remaining ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update( - {(rid, eid): (eid in found_events) for (rid, eid) in remaining} - ) + results.update({eid: (eid in found_events) for eid in remaining}) await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: - res = await self._have_seen_events_dict(((room_id, event_id),)) - return res[(room_id, event_id)] + res = await self._have_seen_events_dict(room_id, [event_id]) + return res[event_id] def _get_current_state_event_counts_txn( self, txn: LoggingTransaction, room_id: str diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 3909f1caea..0391966462 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args + if num_args != self.num_args: + raise Exception( + "Number of args (%s) does not match underlying cache_method_name=%s (%s)." + % (self.num_args, self.cached_method_name, num_args) + ) + @functools.wraps(self.orig) def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 67401272ac..32a798d74b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -35,66 +35,45 @@ from synapse.util import Clock from synapse.util.async_helpers import yieldable_gather_results from tests import unittest +from tests.test_utils.event_injection import create_event, inject_event class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor, clock, hs): + self.hs = hs self.store: EventsWorkerStore = hs.get_datastores().main - # insert some test data - for rid in ("room1", "room2"): - self.get_success( - self.store.db_pool.simple_insert( - "rooms", - {"room_id": rid, "room_version": 4}, - ) - ) + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + self.room_id = self.helper.create_room_as(self.user, tok=self.token) self.event_ids: List[str] = [] - for idx, rid in enumerate( - ( - "room1", - "room1", - "room1", - "room2", - ) - ): - event_json = {"type": f"test {idx}", "room_id": rid} - event = make_event_from_dict(event_json, room_version=RoomVersions.V4) - event_id = event.event_id - - self.get_success( - self.store.db_pool.simple_insert( - "events", - { - "event_id": event_id, - "room_id": rid, - "topological_ordering": idx, - "stream_ordering": idx, - "type": event.type, - "processed": True, - "outlier": False, - }, + for i in range(3): + event = self.get_success( + inject_event( + hs, + room_version=RoomVersions.V7.identifier, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": f"foobarbaz{i}"}, ) ) - self.get_success( - self.store.db_pool.simple_insert( - "event_json", - { - "event_id": event_id, - "room_id": rid, - "json": json.dumps(event_json), - "internal_metadata": "{}", - "format_version": 3, - }, - ) - ) - self.event_ids.append(event_id) + + self.event_ids.append(event.event_id) def test_simple(self): with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) @@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) @@ -116,11 +97,86 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0]]) + self.store.have_seen_events(self.room_id, [self.event_ids[0]]) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) + def test_persisting_event_invalidates_cache(self): + """ + Test to make sure that the `have_seen_event` cache + is invalidated after we persist an event and returns + the updated value. + """ + event, event_context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": "garply"}, + ) + ) + + with LoggingContext(name="test") as ctx: + # First, check `have_seen_event` for an event we have not seen yet + # to prime the cache with a `false` value. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, set()) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Persist the event which should invalidate or prefill the + # `have_seen_event` cache so we don't return stale values. + persistence = self.hs.get_storage_controllers().persistence + self.get_success( + persistence.persist_event( + event, + event_context, + ) + ) + + with LoggingContext(name="test") as ctx: + # Check `have_seen_event` again and we should see the updated fact + # that we have now seen the event after persisting it. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, {event.event_id}) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + def test_invalidate_cache_by_room_id(self): + """ + Test to make sure that all events associated with the given `(room_id,)` + are invalidated in the `have_seen_event` cache. + """ + with LoggingContext(name="test") as ctx: + # Prime the cache with some values + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Clear the cache with any events associated with the `room_id` + self.store.have_seen_event.invalidate((self.room_id,)) + + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # Since we cleared the cache, it should result in another db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + class EventCacheTestCase(unittest.HomeserverTestCase): """Test that the various layers of event cache works.""" diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 48e616ac74..90861fe522 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Set +from typing import Iterable, Set, Tuple from unittest import mock from twisted.internet import defer, reactor @@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.inner_context_was_finished, "Tried to restart a finished logcontext" ) self.assertEqual(current_context(), SENTINEL_CONTEXT) + + def test_num_args_mismatch(self): + """ + Make sure someone does not accidentally use @cachedList on a method with + a mismatch in the number args to the underlying single cache method. + """ + + class Cls: + @descriptors.cached(tree=True) + def fn(self, room_id, event_id): + pass + + # This is wrong ❌. `@cachedList` expects to be given the same number + # of arguments as the underlying cached function, just with one of + # the arguments being an iterable + @descriptors.cachedList(cached_method_name="fn", list_name="keys") + def list_fn(self, keys: Iterable[Tuple[str, str]]): + pass + + # Corrected syntax ✅ + # + # @cachedList(cached_method_name="fn", list_name="event_ids") + # async def list_fn( + # self, room_id: str, event_ids: Collection[str], + # ) + + obj = Cls() + + # Make sure this raises an error about the arg mismatch + with self.assertRaises(Exception): + obj.list_fn([("foo", "bar")]) -- cgit 1.5.1 From 8ab16a92edd675453c78cfd9974081e374b0f998 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 03:11:48 -0700 Subject: Persist CreateRoom events to DB in a batch (#13800) --- changelog.d/13800.misc | 1 + synapse/handlers/message.py | 663 +++++++++++++++++--------------- synapse/handlers/room.py | 21 +- synapse/handlers/room_batch.py | 3 +- synapse/handlers/room_member.py | 11 +- synapse/replication/http/__init__.py | 2 + synapse/replication/http/send_event.py | 4 +- synapse/replication/http/send_events.py | 171 ++++++++ tests/handlers/test_message.py | 10 +- tests/handlers/test_register.py | 4 +- tests/storage/test_event_chain.py | 8 +- tests/unittest.py | 4 +- 12 files changed, 563 insertions(+), 339 deletions(-) create mode 100644 changelog.d/13800.misc create mode 100644 synapse/replication/http/send_events.py (limited to 'tests/storage') diff --git a/changelog.d/13800.misc b/changelog.d/13800.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13800.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 062f93bc67..00e7645ba5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -56,11 +56,13 @@ from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, + PersistedEventPosition, Requester, RoomAlias, StateMap, @@ -493,6 +495,7 @@ class EventCreationHandler: self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) + self.send_events = ReplicationSendEventsRestServlet.make_client(hs) self.request_ratelimiter = hs.get_request_ratelimiter() @@ -1016,8 +1019,7 @@ class EventCreationHandler: ev = await self.handle_new_client_event( requester=requester, - event=event, - context=context, + events_and_context=[(event, context)], ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, ) @@ -1293,13 +1295,13 @@ class EventCreationHandler: async def handle_new_client_event( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ignore_shadow_ban: bool = False, ) -> EventBase: - """Processes a new event. + """Processes new events. Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. This includes deduplicating, checking auth, persisting, notifying users, sending to remote servers, etc. @@ -1309,8 +1311,7 @@ class EventCreationHandler: Args: requester - event - context + events_and_context: A list of one or more tuples of event, context to be persisted ratelimit extra_users: Any extra users to notify about event @@ -1328,62 +1329,63 @@ class EventCreationHandler: """ extra_users = extra_users or [] - # we don't apply shadow-banning to membership events here. Invites are blocked - # higher up the stack, and we allow shadow-banned users to send join and leave - # events as normal. - if ( - event.type != EventTypes.Member - and not ignore_shadow_ban - and requester.shadow_banned - ): - # We randomly sleep a bit just to annoy the requester. - await self.clock.sleep(random.randint(1, 10)) - raise ShadowBanError() + for event, context in events_and_context: + # we don't apply shadow-banning to membership events here. Invites are blocked + # higher up the stack, and we allow shadow-banned users to send join and leave + # events as normal. + if ( + event.type != EventTypes.Member + and not ignore_shadow_ban + and requester.shadow_banned + ): + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() - if event.is_state(): - prev_event = await self.deduplicate_state_event(event, context) - if prev_event is not None: - logger.info( - "Not bothering to persist state event %s duplicated by %s", - event.event_id, - prev_event.event_id, - ) - return prev_event + if event.is_state(): + prev_event = await self.deduplicate_state_event(event, context) + if prev_event is not None: + logger.info( + "Not bothering to persist state event %s duplicated by %s", + event.event_id, + prev_event.event_id, + ) + return prev_event - if event.internal_metadata.is_out_of_band_membership(): - # the only sort of out-of-band-membership events we expect to see here are - # invite rejections and rescinded knocks that we have generated ourselves. - assert event.type == EventTypes.Member - assert event.content["membership"] == Membership.LEAVE - else: - try: - validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) - except AuthError as err: - logger.warning("Denying new event %r because %s", event, err) - raise err + if event.internal_metadata.is_out_of_band_membership(): + # the only sort of out-of-band-membership events we expect to see here are + # invite rejections and rescinded knocks that we have generated ourselves. + assert event.type == EventTypes.Member + assert event.content["membership"] == Membership.LEAVE + else: + try: + validate_event_for_room_version(event) + await self._event_auth_handler.check_auth_rules_from_context( + event, context + ) + except AuthError as err: + logger.warning("Denying new event %r because %s", event, err) + raise err - # Ensure that we can round trip before trying to persist in db - try: - dump = json_encoder.encode(event.content) - json_decoder.decode(dump) - except Exception: - logger.exception("Failed to encode content: %r", event.content) - raise + # Ensure that we can round trip before trying to persist in db + try: + dump = json_encoder.encode(event.content) + json_decoder.decode(dump) + except Exception: + logger.exception("Failed to encode content: %r", event.content) + raise # We now persist the event (and update the cache in parallel, since we # don't want to block on it). + event, context = events_and_context[0] try: result, _ = await make_deferred_yieldable( gather_results( ( run_in_background( - self._persist_event, + self._persist_events, requester=requester, - event=event, - context=context, + events_and_context=events_and_context, ratelimit=ratelimit, extra_users=extra_users, ), @@ -1407,45 +1409,47 @@ class EventCreationHandler: return result - async def _persist_event( + async def _persist_events( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ) -> EventBase: - """Actually persists the event. Should only be called by + """Actually persists new events. Should only be called by `handle_new_client_event`, and see its docstring for documentation of - the arguments. + the arguments. Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + for event, context in events_and_context: + # Skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). + if not event.internal_metadata.is_historical(): + with opentracing.start_active_span("calculate_push_actions"): + await self._bulk_push_rule_evaluator.action_for_event_by_user( + event, context + ) try: # If we're a worker we need to hit out to the master. - writer_instance = self._events_shard_config.get_instance(event.room_id) + first_event, _ = events_and_context[0] + writer_instance = self._events_shard_config.get_instance( + first_event.room_id + ) if writer_instance != self._instance_name: try: - result = await self.send_event( + result = await self.send_events( instance_name=writer_instance, - event_id=event.event_id, + events_and_context=events_and_context, store=self.store, requester=requester, - event=event, - context=context, ratelimit=ratelimit, extra_users=extra_users, ) @@ -1455,6 +1459,11 @@ class EventCreationHandler: raise stream_id = result["stream_id"] event_id = result["event_id"] + + # If we batch persisted events we return the last persisted event, otherwise + # we return the one event that was persisted + event, _ = events_and_context[-1] + if event_id != event.event_id: # If we get a different event back then it means that its # been de-duplicated, so we replace the given event with the @@ -1467,15 +1476,19 @@ class EventCreationHandler: event.internal_metadata.stream_ordering = stream_id return event - event = await self.persist_and_notify_client_event( - requester, event, context, ratelimit=ratelimit, extra_users=extra_users + event = await self.persist_and_notify_client_events( + requester, + events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, ) return event except Exception: - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - await self.store.remove_push_actions_from_staging(event.event_id) + for event, _ in events_and_context: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + await self.store.remove_push_actions_from_staging(event.event_id) raise async def cache_joined_hosts_for_event( @@ -1569,23 +1582,26 @@ class EventCreationHandler: Codes.BAD_ALIAS, ) - async def persist_and_notify_client_event( + async def persist_and_notify_client_events( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ) -> EventBase: - """Called when we have fully built the event, have already - calculated the push actions for the event, and checked auth. + """Called when we have fully built the events, have already + calculated the push actions for the events, and checked auth. This should only be run on the instance in charge of persisting events. + Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. + Returns: - The persisted event. This may be different than the given event if - it was de-duplicated (e.g. because we had already persisted an - event with the same transaction ID.) + The persisted event, if one event is passed in, or the last event in the + list in the case of batch persisting. If only one event was persisted, the + returned event may be different than the given event if it was de-duplicated + (e.g. because we had already persisted an event with the same transaction ID.) Raises: PartialStateConflictError: if attempting to persist a partial state event in @@ -1593,277 +1609,297 @@ class EventCreationHandler: """ extra_users = extra_users or [] - assert self._storage_controllers.persistence is not None - assert self._events_shard_config.should_handle( - self._instance_name, event.room_id - ) + for event, context in events_and_context: + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) - if ratelimit: - # We check if this is a room admin redacting an event so that we - # can apply different ratelimiting. We do this by simply checking - # it's not a self-redaction (to avoid having to look up whether the - # user is actually admin or not). - is_admin_redaction = False - if event.type == EventTypes.Redaction: - assert event.redacts is not None + if ratelimit: + # We check if this is a room admin redacting an event so that we + # can apply different ratelimiting. We do this by simply checking + # it's not a self-redaction (to avoid having to look up whether the + # user is actually admin or not). + is_admin_redaction = False + if event.type == EventTypes.Redaction: + assert event.redacts is not None + + original_event = await self.store.get_event( + event.redacts, + redact_behaviour=EventRedactBehaviour.as_is, + get_prev_content=False, + allow_rejected=False, + allow_none=True, + ) - original_event = await self.store.get_event( - event.redacts, - redact_behaviour=EventRedactBehaviour.as_is, - get_prev_content=False, - allow_rejected=False, - allow_none=True, + is_admin_redaction = bool( + original_event and event.sender != original_event.sender + ) + + await self.request_ratelimiter.ratelimit( + requester, is_admin_redaction=is_admin_redaction ) - is_admin_redaction = bool( - original_event and event.sender != original_event.sender + # run checks/actions on event based on type + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + event.state_key, event.room_id ) + if current_membership != Membership.JOIN: + self._notifier.notify_user_joined_room( + event.event_id, event.room_id + ) - await self.request_ratelimiter.ratelimit( - requester, is_admin_redaction=is_admin_redaction - ) + await self._maybe_kick_guest_users(event, context) - if event.type == EventTypes.Member and event.membership == Membership.JOIN: - ( - current_membership, - _, - ) = await self.store.get_local_current_membership_for_user_in_room( - event.state_key, event.room_id - ) - if current_membership != Membership.JOIN: - self._notifier.notify_user_joined_room(event.event_id, event.room_id) + if event.type == EventTypes.CanonicalAlias: + # Validate a newly added alias or newly added alt_aliases. - await self._maybe_kick_guest_users(event, context) + original_alias = None + original_alt_aliases: object = [] - if event.type == EventTypes.CanonicalAlias: - # Validate a newly added alias or newly added alt_aliases. + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_alias_event = await self.store.get_event(original_event_id) - original_alias = None - original_alt_aliases: object = [] + if original_alias_event: + original_alias = original_alias_event.content.get("alias", None) + original_alt_aliases = original_alias_event.content.get( + "alt_aliases", [] + ) - original_event_id = event.unsigned.get("replaces_state") - if original_event_id: - original_event = await self.store.get_event(original_event_id) + # Check the alias is currently valid (if it has changed). + room_alias_str = event.content.get("alias", None) + directory_handler = self.hs.get_directory_handler() + if room_alias_str and room_alias_str != original_alias: + await self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) - if original_event: - original_alias = original_event.content.get("alias", None) - original_alt_aliases = original_event.content.get("alt_aliases", []) - - # Check the alias is currently valid (if it has changed). - room_alias_str = event.content.get("alias", None) - directory_handler = self.hs.get_directory_handler() - if room_alias_str and room_alias_str != original_alias: - await self._validate_canonical_alias( - directory_handler, room_alias_str, event.room_id - ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, + "The alt_aliases property must be a list.", + Codes.INVALID_PARAM, + ) - # Check that alt_aliases is the proper form. - alt_aliases = event.content.get("alt_aliases", []) - if not isinstance(alt_aliases, (list, tuple)): - raise SynapseError( - 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM - ) + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + # TODO: check that the original_alt_aliases' entries are all strings + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + await self._validate_canonical_alias( + directory_handler, alias_str, event.room_id + ) - # If the old version of alt_aliases is of an unknown form, - # completely replace it. - if not isinstance(original_alt_aliases, (list, tuple)): - # TODO: check that the original_alt_aliases' entries are all strings - original_alt_aliases = [] + federation_handler = self.hs.get_federation_handler() - # Check that each alias is currently valid. - new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) - if new_alt_aliases: - for alias_str in new_alt_aliases: - await self._validate_canonical_alias( - directory_handler, alias_str, event.room_id + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.INVITE: + event.unsigned[ + "invite_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, ) - federation_handler = self.hs.get_federation_handler() + invitee = UserID.from_string(event.state_key) + if not self.hs.is_mine(invitee): + # TODO: Can we add signature from remote server in a nicer + # way? If we have been invited by a remote server, we need + # to get them to sign the event. - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, - ) + returned_invite = await federation_handler.send_invite( + invitee.domain, event + ) + event.unsigned.pop("room_state", None) - invitee = UserID.from_string(event.state_key) - if not self.hs.is_mine(invitee): - # TODO: Can we add signature from remote server in a nicer - # way? If we have been invited by a remote server, we need - # to get them to sign the event. + # TODO: Make sure the signatures actually are correct. + event.signatures.update(returned_invite.signatures) - returned_invite = await federation_handler.send_invite( - invitee.domain, event + if event.content["membership"] == Membership.KNOCK: + event.unsigned[ + "knock_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, ) - event.unsigned.pop("room_state", None) - # TODO: Make sure the signatures actually are correct. - event.signatures.update(returned_invite.signatures) + if event.type == EventTypes.Redaction: + assert event.redacts is not None - if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + original_event = await self.store.get_event( + event.redacts, + redact_behaviour=EventRedactBehaviour.as_is, + get_prev_content=False, + allow_rejected=False, + allow_none=True, ) - if event.type == EventTypes.Redaction: - assert event.redacts is not None + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - original_event = await self.store.get_event( - event.redacts, - redact_behaviour=EventRedactBehaviour.as_is, - get_prev_content=False, - allow_rejected=False, - allow_none=True, - ) + # we can make some additional checks now if we have the original event. + if original_event: + if original_event.type == EventTypes.Create: + raise AuthError(403, "Redacting create events is not permitted") - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # we can make some additional checks now if we have the original event. - if original_event: - if original_event.type == EventTypes.Create: - raise AuthError(403, "Redacting create events is not permitted") - - if original_event.room_id != event.room_id: - raise SynapseError(400, "Cannot redact event from a different room") - - if original_event.type == EventTypes.ServerACL: - raise AuthError(403, "Redacting server ACL events is not permitted") - - # Add a little safety stop-gap to prevent people from trying to - # redact MSC2716 related events when they're in a room version - # which does not support it yet. We allow people to use MSC2716 - # events in existing room versions but only from the room - # creator since it does not require any changes to the auth - # rules and in effect, the redaction algorithm . In the - # supported room version, we add the `historical` power level to - # auth the MSC2716 related events and adjust the redaction - # algorthim to keep the `historical` field around (redacting an - # event should only strip fields which don't affect the - # structural protocol level). - is_msc2716_event = ( - original_event.type == EventTypes.MSC2716_INSERTION - or original_event.type == EventTypes.MSC2716_BATCH - or original_event.type == EventTypes.MSC2716_MARKER - ) - if not room_version_obj.msc2716_historical and is_msc2716_event: - raise AuthError( - 403, - "Redacting MSC2716 events is not supported in this room version", - ) + if original_event.room_id != event.room_id: + raise SynapseError( + 400, "Cannot redact event from a different room" + ) - event_types = event_auth.auth_types_for_event(event.room_version, event) - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types(event_types) - ) + if original_event.type == EventTypes.ServerACL: + raise AuthError( + 403, "Redacting server ACL events is not permitted" + ) - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_map = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} + # Add a little safety stop-gap to prevent people from trying to + # redact MSC2716 related events when they're in a room version + # which does not support it yet. We allow people to use MSC2716 + # events in existing room versions but only from the room + # creator since it does not require any changes to the auth + # rules and in effect, the redaction algorithm . In the + # supported room version, we add the `historical` power level to + # auth the MSC2716 related events and adjust the redaction + # algorthim to keep the `historical` field around (redacting an + # event should only strip fields which don't affect the + # structural protocol level). + is_msc2716_event = ( + original_event.type == EventTypes.MSC2716_INSERTION + or original_event.type == EventTypes.MSC2716_BATCH + or original_event.type == EventTypes.MSC2716_MARKER + ) + if not room_version_obj.msc2716_historical and is_msc2716_event: + raise AuthError( + 403, + "Redacting MSC2716 events is not supported in this room version", + ) - if event_auth.check_redaction( - room_version_obj, event, auth_events=auth_events - ): - # this user doesn't have 'redact' rights, so we need to do some more - # checks on the original event. Let's start by checking the original - # event exists. - if not original_event: - raise NotFoundError("Could not find event %s" % (event.redacts,)) - - if event.user_id != original_event.user_id: - raise AuthError(403, "You don't have permission to redact events") - - # all the checks are done. - event.internal_metadata.recheck_redaction = False - - if event.type == EventTypes.Create: - prev_state_ids = await context.get_prev_state_ids() - if prev_state_ids: - raise AuthError(403, "Changing the room create event is forbidden") - - if event.type == EventTypes.MSC2716_INSERTION: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - create_event = await self.store.get_create_event_for_room(event.room_id) - room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) - - # Only check an insertion event if the room version - # supports it or the event is from the room creator. - if room_version_obj.msc2716_historical or ( - self.config.experimental.msc2716_enabled - and event.sender == room_creator - ): - next_batch_id = event.content.get( - EventContentFields.MSC2716_NEXT_BATCH_ID + event_types = event_auth.auth_types_for_event(event.room_version, event) + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types(event_types) ) - conflicting_insertion_event_id = None - if next_batch_id: - conflicting_insertion_event_id = ( - await self.store.get_insertion_event_id_by_batch_id( - event.room_id, next_batch_id + + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events_map.values() + } + + if event_auth.check_redaction( + room_version_obj, event, auth_events=auth_events + ): + # this user doesn't have 'redact' rights, so we need to do some more + # checks on the original event. Let's start by checking the original + # event exists. + if not original_event: + raise NotFoundError( + "Could not find event %s" % (event.redacts,) ) + + if event.user_id != original_event.user_id: + raise AuthError( + 403, "You don't have permission to redact events" + ) + + # all the checks are done. + event.internal_metadata.recheck_redaction = False + + if event.type == EventTypes.Create: + prev_state_ids = await context.get_prev_state_ids() + if prev_state_ids: + raise AuthError(403, "Changing the room create event is forbidden") + + if event.type == EventTypes.MSC2716_INSERTION: + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + create_event = await self.store.get_create_event_for_room(event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + # Only check an insertion event if the room version + # supports it or the event is from the room creator. + if room_version_obj.msc2716_historical or ( + self.config.experimental.msc2716_enabled + and event.sender == room_creator + ): + next_batch_id = event.content.get( + EventContentFields.MSC2716_NEXT_BATCH_ID ) - if conflicting_insertion_event_id is not None: - # The current insertion event that we're processing is invalid - # because an insertion event already exists in the room with the - # same next_batch_id. We can't allow multiple because the batch - # pointing will get weird, e.g. we can't determine which insertion - # event the batch event is pointing to. - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Another insertion event already exists with the same next_batch_id", - errcode=Codes.INVALID_PARAM, - ) + conflicting_insertion_event_id = None + if next_batch_id: + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_id_by_batch_id( + event.room_id, next_batch_id + ) + ) + if conflicting_insertion_event_id is not None: + # The current insertion event that we're processing is invalid + # because an insertion event already exists in the room with the + # same next_batch_id. We can't allow multiple because the batch + # pointing will get weird, e.g. we can't determine which insertion + # event the batch event is pointing to. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Another insertion event already exists with the same next_batch_id", + errcode=Codes.INVALID_PARAM, + ) - # Mark any `m.historical` messages as backfilled so they don't appear - # in `/sync` and have the proper decrementing `stream_ordering` as we import - backfilled = False - if event.internal_metadata.is_historical(): - backfilled = True + # Mark any `m.historical` messages as backfilled so they don't appear + # in `/sync` and have the proper decrementing `stream_ordering` as we import + backfilled = False + if event.internal_metadata.is_historical(): + backfilled = True - # Note that this returns the event that was persisted, which may not be - # the same as we passed in if it was deduplicated due transaction IDs. + assert self._storage_controllers.persistence is not None ( - event, - event_pos, + persisted_events, max_stream_token, - ) = await self._storage_controllers.persistence.persist_event( - event, context=context, backfilled=backfilled + ) = await self._storage_controllers.persistence.persist_events( + events_and_context, backfilled=backfilled ) - if self._ephemeral_events_enabled: - # If there's an expiry timestamp on the event, schedule its expiry. - self._message_handler.maybe_schedule_expiry(event) + for event in persisted_events: + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) - async def _notify() -> None: - try: - await self.notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users - ) - except Exception: - logger.exception( - "Error notifying about new room event %s", - event.event_id, - ) + stream_ordering = event.internal_metadata.stream_ordering + assert stream_ordering is not None + pos = PersistedEventPosition(self._instance_name, stream_ordering) + + async def _notify() -> None: + try: + await self.notifier.on_new_room_event( + event, pos, max_stream_token, extra_users=extra_users + ) + except Exception: + logger.exception( + "Error notifying about new room event %s", + event.event_id, + ) - run_in_background(_notify) + run_in_background(_notify) - if event.type == EventTypes.Message: - # We don't want to block sending messages on any presence code. This - # matters as sometimes presence code can take a while. - run_in_background(self._bump_active_time, requester.user) + if event.type == EventTypes.Message: + # We don't want to block sending messages on any presence code. This + # matters as sometimes presence code can take a while. + run_in_background(self._bump_active_time, requester.user) - return event + return persisted_events[-1] async def _maybe_kick_guest_users( self, event: EventBase, context: EventContext @@ -1952,8 +1988,7 @@ class EventCreationHandler: # shadow-banned user. await self.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], ratelimit=False, ignore_shadow_ban=True, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 09a1a82e6c..b220238e55 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -301,8 +301,7 @@ class RoomCreationHandler: # now send the tombstone await self.event_creation_handler.handle_new_client_event( requester=requester, - event=tombstone_event, - context=tombstone_context, + events_and_context=[(tombstone_event, tombstone_context)], ) state_filter = StateFilter.from_types( @@ -1057,8 +1056,10 @@ class RoomCreationHandler: creator_id = creator.user.to_string() event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None + # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1112,8 +1113,7 @@ class RoomCreationHandler: ev = await self.event_creation_handler.handle_new_client_event( requester=creator, - event=event, - context=context, + events_and_context=[(event, context)], ratelimit=False, ignore_shadow_ban=True, ) @@ -1152,7 +1152,6 @@ class RoomCreationHandler: prev_event_ids=[last_sent_event_id], depth=depth, ) - last_sent_event_id = member_event_id prev_event = [member_event_id] # update the depth and state map here as the membership event has been created @@ -1168,7 +1167,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - last_sent_stream_id = await send(power_event, power_context, creator) + await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1217,7 +1216,7 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - last_sent_stream_id = await send(pl_event, pl_context, creator) + await send(pl_event, pl_context, creator) events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: @@ -1271,9 +1270,11 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) - for event, context in events_to_send: - last_sent_stream_id = await send(event, context, creator) - return last_sent_stream_id, last_sent_event_id, depth + last_event = await self.event_creation_handler.handle_new_client_event( + creator, events_to_send, ignore_shadow_ban=True + ) + assert last_event.internal_metadata.stream_ordering is not None + return last_event.internal_metadata.stream_ordering, last_event.event_id, depth def _generate_room_id(self) -> str: """Generates a random room ID. diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 1414e575d6..411a6fb22f 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -379,8 +379,7 @@ class RoomBatchHandler: await self.create_requester_for_user_id_from_app_service( event.sender, app_service_requester.app_service ), - event=event, - context=context, + events_and_context=[(event, context)], ) return event_ids diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8d01f4bf2b..88158822e0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -432,8 +432,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): with opentracing.start_active_span("handle_new_client_event"): result_event = await self.event_creation_handler.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], extra_users=[target], ratelimit=ratelimit, ) @@ -1252,7 +1251,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target_user], ratelimit=ratelimit + requester, + events_and_context=[(event, context)], + extra_users=[target_user], + ratelimit=ratelimit, ) prev_member_event_id = prev_state_ids.get( @@ -1860,8 +1862,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): result_event = await self.event_creation_handler.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 53aa7fa4c6..ac9a92240a 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -25,6 +25,7 @@ from synapse.replication.http import ( push, register, send_event, + send_events, state, streams, ) @@ -43,6 +44,7 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs: "HomeServer") -> None: send_event.register_servlets(hs, self) + send_events.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) membership.register_servlets(hs, self) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 486f04723c..4215a1c1bc 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -141,8 +141,8 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - event = await self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, ratelimit=ratelimit, extra_users=extra_users + event = await self.event_creation_handler.persist_and_notify_client_events( + requester, [(event, context)], ratelimit=ratelimit, extra_users=extra_users ) return ( diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py new file mode 100644 index 0000000000..8889bbb644 --- /dev/null +++ b/synapse/replication/http/send_events.py @@ -0,0 +1,171 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, List, Tuple + +from twisted.web.server import Request + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, Requester, UserID +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + +logger = logging.getLogger(__name__) + + +class ReplicationSendEventsRestServlet(ReplicationEndpoint): + """Handles batches of newly created events on workers, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/send_events/:txn_id + + { + "events": [{ + "event": { .. serialized event .. }, + "room_version": .., // "1", "2", "3", etc: the version of the room + // containing the event + "event_format_version": .., // 1,2,3 etc: the event format version + "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + "requester": { .. serialized requester .. }, + "ratelimit": true, + }] + } + + 200 OK + + { "stream_id": 12345, "event_id": "$abcdef..." } + + Responds with a 409 when a `PartialStateConflictError` is raised due to an event + context that needs to be recomputed due to the un-partial stating of a room. + + """ + + NAME = "send_events" + PATH_ARGS = () + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + events_and_context: List[Tuple[EventBase, EventContext]], + store: "DataStore", + requester: Requester, + ratelimit: bool, + extra_users: List[UserID], + ) -> JsonDict: + """ + Args: + store + requester + events_and_ctx + ratelimit + """ + serialized_events = [] + + for event, context in events_and_context: + serialized_context = await context.serialize(event, store) + serialized_event = { + "event": event.get_pdu_json(), + "room_version": event.room_version.identifier, + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } + serialized_events.append(serialized_event) + + payload = {"events": serialized_events} + + return payload + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + with Measure(self.clock, "repl_send_events_parse"): + payload = parse_json_object_from_request(request) + events_and_context = [] + events = payload["events"] + + for event_payload in events: + event_dict = event_payload["event"] + room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]] + internal_metadata = event_payload["internal_metadata"] + rejected_reason = event_payload["rejected_reason"] + + event = make_event_from_dict( + event_dict, room_ver, internal_metadata, rejected_reason + ) + event.internal_metadata.outlier = event_payload["outlier"] + + requester = Requester.deserialize( + self.store, event_payload["requester"] + ) + context = EventContext.deserialize( + self._storage_controllers, event_payload["context"] + ) + + ratelimit = event_payload["ratelimit"] + events_and_context.append((event, context)) + + extra_users = [ + UserID.from_string(u) for u in event_payload["extra_users"] + ] + + logger.info( + "Got batch of events to send, last ID of batch is: %s, sending into room: %s", + event.event_id, + event.room_id, + ) + + last_event = ( + await self.event_creation_handler.persist_and_notify_client_events( + requester, events_and_context, ratelimit, extra_users + ) + ) + + return ( + 200, + { + "stream_id": last_event.internal_metadata.stream_ordering, + "event_id": last_event.event_id, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationSendEventsRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 986b50ce0c..99384837d0 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -105,7 +105,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): event1, context = self._create_duplicate_event(txn_id) ret_event1 = self.get_success( - self.handler.handle_new_client_event(self.requester, event1, context) + self.handler.handle_new_client_event( + self.requester, + events_and_context=[(event1, context)], + ) ) stream_id1 = ret_event1.internal_metadata.stream_ordering @@ -118,7 +121,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) ret_event2 = self.get_success( - self.handler.handle_new_client_event(self.requester, event2, context) + self.handler.handle_new_client_event( + self.requester, + events_and_context=[(event2, context)], + ) ) stream_id2 = ret_event2.internal_metadata.stream_ordering diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 86b3d51975..765df75d91 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -497,7 +497,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - event_creation_handler.handle_new_client_event(requester, event, context) + event_creation_handler.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) # Register a second user, which won't be be in the room (or even have an invite) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index a0ce077a99..de9f4af2de 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -531,7 +531,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) self.get_success( - event_handler.handle_new_client_event(self.requester, event, context) + event_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) ) state1 = set(self.get_success(context.get_current_state_ids()).values()) @@ -549,7 +551,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) self.get_success( - event_handler.handle_new_client_event(self.requester, event, context) + event_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) ) state2 = set(self.get_success(context.get_current_state_ids()).values()) diff --git a/tests/unittest.py b/tests/unittest.py index 00cb023198..5116be338e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -734,7 +734,9 @@ class HomeserverTestCase(TestCase): event.internal_metadata.soft_failed = True self.get_success( - event_creator.handle_new_client_event(requester, event, context) + event_creator.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) return event.event_id -- cgit 1.5.1 From df8b91ed2bba4995c59a5b067e3b252ab90c9a5e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 28 Sep 2022 15:26:16 -0500 Subject: Limit and filter the number of backfill points to get from the database (#13879) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There is no need to grab thousands of backfill points when we only need 5 to make the `/backfill` request with. We need to grab a few extra in case the first few aren't visible in the history. Previously, we grabbed thousands of backfill points from the database, then sorted and filtered them in the app. Fetching the 4.6k backfill points for `#matrix:matrix.org` from the database takes ~50ms - ~570ms so it's not like this saves a lot of time 🤷. But it might save us more time now that `get_backfill_points_in_room`/`get_insertion_event_backward_extremities_in_room` are more complicated after https://github.com/matrix-org/synapse/pull/13635 This PR moves the filtering and limiting to the SQL query so we just have less data to work with in the first place. Part of https://github.com/matrix-org/synapse/issues/13356 --- changelog.d/13879.misc | 1 + synapse/handlers/federation.py | 109 ++++++++++++--------- synapse/storage/databases/main/event_federation.py | 90 ++++++++++++++--- tests/storage/test_event_federation.py | 80 ++++++++++----- 4 files changed, 198 insertions(+), 82 deletions(-) create mode 100644 changelog.d/13879.misc (limited to 'tests/storage') diff --git a/changelog.d/13879.misc b/changelog.d/13879.misc new file mode 100644 index 0000000000..3cc2a2420f --- /dev/null +++ b/changelog.d/13879.misc @@ -0,0 +1 @@ +Only pull relevant backfill points from the database based on the current depth and limit (instead of all) every time we want to `/backfill`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 360ab6fee2..500c1c16d0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -38,7 +38,7 @@ from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 from synapse import event_auth -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, CodeMessageException, @@ -211,7 +211,7 @@ class FederationHandler: current_depth: int, limit: int, *, - processing_start_time: int, + processing_start_time: Optional[int], ) -> bool: """ Checks whether the `current_depth` is at or approaching any backfill @@ -223,12 +223,23 @@ class FederationHandler: room_id: The room to backfill in. current_depth: The depth to check at for any upcoming backfill points. limit: The max number of events to request from the remote federated server. - processing_start_time: The time when `maybe_backfill` started - processing. Only used for timing. + processing_start_time: The time when `maybe_backfill` started processing. + Only used for timing. If `None`, no timing observation will be made. """ backwards_extremities = [ _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) - for event_id, depth in await self.store.get_backfill_points_in_room(room_id) + for event_id, depth in await self.store.get_backfill_points_in_room( + room_id=room_id, + current_depth=current_depth, + # We only need to end up with 5 extremities combined with the + # insertion event extremities to make the `/backfill` request + # but fetch an order of magnitude more to make sure there is + # enough even after we filter them by whether visible in the + # history. This isn't fool-proof as all backfill points within + # our limit could be filtered out but seems like a good amount + # to try with at least. + limit=50, + ) ] insertion_events_to_be_backfilled: List[_BackfillPoint] = [] @@ -236,7 +247,12 @@ class FederationHandler: insertion_events_to_be_backfilled = [ _BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT) for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room( - room_id + room_id=room_id, + current_depth=current_depth, + # We only need to end up with 5 extremities combined with + # the backfill points to make the `/backfill` request ... + # (see the other comment above for more context). + limit=50, ) ] logger.debug( @@ -245,10 +261,6 @@ class FederationHandler: insertion_events_to_be_backfilled, ) - if not backwards_extremities and not insertion_events_to_be_backfilled: - logger.debug("Not backfilling as no extremeties found.") - return False - # we now have a list of potential places to backpaginate from. We prefer to # start with the most recent (ie, max depth), so let's sort the list. sorted_backfill_points: List[_BackfillPoint] = sorted( @@ -269,6 +281,33 @@ class FederationHandler: sorted_backfill_points, ) + # If we have no backfill points lower than the `current_depth` then + # either we can a) bail or b) still attempt to backfill. We opt to try + # backfilling anyway just in case we do get relevant events. + if not sorted_backfill_points and current_depth != MAX_DEPTH: + logger.debug( + "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points." + ) + return await self._maybe_backfill_inner( + room_id=room_id, + # We use `MAX_DEPTH` so that we find all backfill points next + # time (all events are below the `MAX_DEPTH`) + current_depth=MAX_DEPTH, + limit=limit, + # We don't want to start another timing observation from this + # nested recursive call. The top-most call can record the time + # overall otherwise the smaller one will throw off the results. + processing_start_time=None, + ) + + # Even after recursing with `MAX_DEPTH`, we didn't find any + # backward extremities to backfill from. + if not sorted_backfill_points: + logger.debug( + "_maybe_backfill_inner: Not backfilling as no backward extremeties found." + ) + return False + # If we're approaching an extremity we trigger a backfill, otherwise we # no-op. # @@ -278,47 +317,16 @@ class FederationHandler: # chose more than one times the limit in case of failure, but choosing a # much larger factor will result in triggering a backfill request much # earlier than necessary. - # - # XXX: shouldn't we do this *after* the filter by depth below? Again, we don't - # care about events that have happened after our current position. - # - max_depth = sorted_backfill_points[0].depth - if current_depth - 2 * limit > max_depth: + max_depth_of_backfill_points = sorted_backfill_points[0].depth + if current_depth - 2 * limit > max_depth_of_backfill_points: logger.debug( "Not backfilling as we don't need to. %d < %d - 2 * %d", - max_depth, + max_depth_of_backfill_points, current_depth, limit, ) return False - # We ignore extremities that have a greater depth than our current depth - # as: - # 1. we don't really care about getting events that have happened - # after our current position; and - # 2. we have likely previously tried and failed to backfill from that - # extremity, so to avoid getting "stuck" requesting the same - # backfill repeatedly we drop those extremities. - # - # However, we need to check that the filtered extremities are non-empty. - # If they are empty then either we can a) bail or b) still attempt to - # backfill. We opt to try backfilling anyway just in case we do get - # relevant events. - # - filtered_sorted_backfill_points = [ - t for t in sorted_backfill_points if t.depth <= current_depth - ] - if filtered_sorted_backfill_points: - logger.debug( - "_maybe_backfill_inner: backfill points before current depth: %s", - filtered_sorted_backfill_points, - ) - sorted_backfill_points = filtered_sorted_backfill_points - else: - logger.debug( - "_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway." - ) - # For performance's sake, we only want to paginate from a particular extremity # if we can actually see the events we'll get. Otherwise, we'd just spend a lot # of resources to get redacted events. We check each extremity in turn and @@ -452,10 +460,15 @@ class FederationHandler: return False - processing_end_time = self.clock.time_msec() - backfill_processing_before_timer.observe( - (processing_end_time - processing_start_time) / 1000 - ) + # If we have the `processing_start_time`, then we can make an + # observation. We wouldn't have the `processing_start_time` in the case + # where `_maybe_backfill_inner` is recursively called to find any + # backfill points regardless of `current_depth`. + if processing_start_time is not None: + processing_end_time = self.clock.time_msec() + backfill_processing_before_timer.observe( + (processing_end_time - processing_start_time) / 1000 + ) success = await try_backfill(likely_domains) if success: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 3251fca6fb..17f2fd4458 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -726,17 +726,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas async def get_backfill_points_in_room( self, room_id: str, + current_depth: int, + limit: int, ) -> List[Tuple[str, int]]: """ - Gets the oldest events(backwards extremities) in the room along with the - approximate depth. Sorted by depth, highest to lowest (descending). + Get the backward extremities to backfill from in the room along with the + approximate depth. + + Only returns events that are at a depth lower than or + equal to the `current_depth`. Sorted by depth, highest to lowest (descending) + so the closest events to the `current_depth` are first in the list. + + We ignore extremities that are newer than the user's current scroll position + (ie, those with depth greater than `current_depth`) as: + 1. we don't really care about getting events that have happened + after our current position; and + 2. by the nature of paginating and scrolling back, we have likely + previously tried and failed to backfill from that extremity, so + to avoid getting "stuck" requesting the same backfill repeatedly + we drop those extremities. Args: room_id: Room where we want to find the oldest events + current_depth: The depth at the user's current scrollback position + limit: The max number of backfill points to return Returns: List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) + (descending) so the closest events to the `current_depth` are first + in the list. """ def get_backfill_points_in_room_txn( @@ -784,6 +802,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas * necessarily safe to assume that it will have been completed. */ AND edge.is_state is ? /* False */ + /** + * We only want backwards extremities that are older than or at + * the same position of the given `current_depth` (where older + * means less than the given depth) because we're looking backwards + * from the `current_depth` when backfilling. + * + * current_depth (ignore events that come after this, ignore 2-4) + * | + * ▼ + * [0]<--[1]<--[2]<--[3]<--[4] + */ + AND event.depth <= ? /* current_depth */ /** * Exponential back-off (up to the upper bound) so we don't retry the * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. @@ -798,11 +828,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) ) /** - * Sort from highest to the lowest depth. Then tie-break on - * alphabetical order of the event_ids so we get a consistent - * ordering which is nice when asserting things in tests. + * Sort from highest (closest to the `current_depth`) to the lowest depth + * because the closest are most relevant to backfill from first. + * Then tie-break on alphabetical order of the event_ids so we get a + * consistent ordering which is nice when asserting things in tests. */ ORDER BY event.depth DESC, backward_extrem.event_id DESC + LIMIT ? """ if isinstance(self.database_engine, PostgresEngine): @@ -817,9 +849,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ( room_id, False, + current_depth, self._clock.time_msec(), 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + limit, ), ) @@ -835,18 +869,34 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas async def get_insertion_event_backward_extremities_in_room( self, room_id: str, + current_depth: int, + limit: int, ) -> List[Tuple[str, int]]: """ Get the insertion events we know about that we haven't backfilled yet - along with the approximate depth. Sorted by depth, highest to lowest - (descending). + along with the approximate depth. Only returns insertion events that are + at a depth lower than or equal to the `current_depth`. Sorted by depth, + highest to lowest (descending) so the closest events to the + `current_depth` are first in the list. + + We ignore insertion events that are newer than the user's current scroll + position (ie, those with depth greater than `current_depth`) as: + 1. we don't really care about getting events that have happened + after our current position; and + 2. by the nature of paginating and scrolling back, we have likely + previously tried and failed to backfill from that insertion event, so + to avoid getting "stuck" requesting the same backfill repeatedly + we drop those insertion event. Args: room_id: Room where we want to find the oldest events + current_depth: The depth at the user's current scrollback position + limit: The max number of insertion event extremities to return Returns: List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) + (descending) so the closest events to the `current_depth` are first + in the list. """ def get_insertion_event_backward_extremities_in_room_txn( @@ -869,6 +919,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id WHERE insertion_event_extremity.room_id = ? + /** + * We only want extremities that are older than or at + * the same position of the given `current_depth` (where older + * means less than the given depth) because we're looking backwards + * from the `current_depth` when backfilling. + * + * current_depth (ignore events that come after this, ignore 2-4) + * | + * ▼ + * [0]<--[1]<--[2]<--[3]<--[4] + */ + AND event.depth <= ? /* current_depth */ /** * Exponential back-off (up to the upper bound) so we don't retry the * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc @@ -883,11 +945,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) ) /** - * Sort from highest to the lowest depth. Then tie-break on - * alphabetical order of the event_ids so we get a consistent - * ordering which is nice when asserting things in tests. + * Sort from highest (closest to the `current_depth`) to the lowest depth + * because the closest are most relevant to backfill from first. + * Then tie-break on alphabetical order of the event_ids so we get a + * consistent ordering which is nice when asserting things in tests. */ ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC + LIMIT ? """ if isinstance(self.database_engine, PostgresEngine): @@ -901,9 +965,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas sql % (least_function,), ( room_id, + current_depth, self._clock.time_msec(), 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + limit, ), ) return cast(List[Tuple[str, int]], txn.fetchall()) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 85739c464e..398f338b66 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -754,19 +754,31 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_backfill_points_in_room(self): """ - Test to make sure we get some backfill points + Test to make sure only backfill points that are older and come before + the `current_depth` are returned. """ setup_info = self._setup_room_for_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map + # Try at "B" backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertListEqual( backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] ) + # Try at "A" + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Event "2" has a depth of 2 but is not included here because we only + # know the approximate depth of 5 from our event "3". + self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"]) + def test_get_backfill_points_in_room_excludes_events_we_have_attempted( self, ): @@ -776,6 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_backfill_points_in_room` exclude them because we @@ -795,8 +808,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # No time has passed since we attempted to backfill ^ + # Try at "B" backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. @@ -812,6 +826,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_backfill_points_in_room` exclude them because we @@ -839,26 +854,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # visible regardless. self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) - # Make sure that "b1" is not in the list because we've + # Try at "A" and make sure that "b1" is not in the list because we've # already attempted many times backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"]) + self.assertListEqual(backfill_event_ids, ["b3", "b2"]) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) - # Try again after we advanced enough time and we should see "b3" again + # Try at "A" again after we advanced enough time and we should see "b3" again backfill_points = self.get_success( - self.store.get_backfill_points_in_room(room_id) + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] - ) + self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"]) def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: """ @@ -938,19 +951,35 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_insertion_event_backward_extremities_in_room(self): """ - Test to make sure insertion event backward extremities are returned. + Test to make sure only insertion event backward extremities that are + older and come before the `current_depth` are returned. """ setup_info = self._setup_room_for_insertion_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map + # Try at "insertion_eventB" backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventB"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertListEqual( backfill_event_ids, ["insertion_eventB", "insertion_eventA"] ) + # Try at "insertion_eventA" + backfill_points = self.get_success( + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventA"], limit=100 + ) + ) + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + # Event "2" has a depth of 2 but is not included here because we only + # know the approximate depth of 5 from our event "3". + self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) + def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( self, ): @@ -961,6 +990,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_insertion_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_insertion_event_backward_extremities_in_room` exclude them @@ -973,8 +1003,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # No time has passed since we attempted to backfill ^ + # Try at "insertion_eventB" backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventB"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. @@ -991,6 +1024,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): """ setup_info = self._setup_room_for_insertion_backfill_tests() room_id = setup_info.room_id + depth_map = setup_info.depth_map # Record some attempts to backfill these events which will make # `get_backfill_points_in_room` exclude them because we @@ -1027,13 +1061,15 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # because we haven't waited long enough for this many attempts. self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) - # Make sure that "insertion_eventA" is not in the list because we've - # already attempted many times + # Try at "insertion_eventA" and make sure that "insertion_eventA" is not + # in the list because we've already attempted many times backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventA"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + self.assertListEqual(backfill_event_ids, []) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it @@ -1042,12 +1078,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Try at "insertion_eventA" again after we advanced enough time and we # should see "insertion_eventA" again backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room(room_id) + self.store.get_insertion_event_backward_extremities_in_room( + room_id, depth_map["insertion_eventA"], limit=100 + ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["insertion_eventB", "insertion_eventA"] - ) + self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) @attr.s -- cgit 1.5.1 From 568016929f3d22f632cb9145429fa45754a8d59f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 29 Sep 2022 07:07:31 -0400 Subject: Clarify that a method returns only unthreaded receipts. (#13937) By renaming it and updating the docstring. Additionally, refactors a method which is used only by tests. --- changelog.d/13937.feature | 1 + .../storage/databases/main/event_push_actions.py | 12 +--- synapse/storage/databases/main/receipts.py | 36 ++--------- tests/storage/test_receipts.py | 74 +++++++++++----------- 4 files changed, 47 insertions(+), 76 deletions(-) create mode 100644 changelog.d/13937.feature (limited to 'tests/storage') diff --git a/changelog.d/13937.feature b/changelog.d/13937.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13937.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index f4cdc2e399..7e0ffef7d3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, ) -> NotifCounts: # Get the stream ordering of the user's latest receipt in the room. - result = self.get_last_receipt_for_user_txn( + result = self.get_last_unthreaded_receipt_for_user_txn( txn, user_id, room_id, - receipt_types=( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ), + receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) if result: @@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", - ( - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ), + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) sql = f""" diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 52fe0db924..246f78ac1f 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore): """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() - async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_types: Collection[str] - ) -> Optional[str]: - """ - Fetch the event ID for the latest receipt in a room with one of the given receipt types. - - Args: - user_id: The user to fetch receipts for. - room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. - - Returns: - The latest receipt, if one exists. - """ - result = await self.db_pool.runInteraction( - "get_last_receipt_event_id_for_user", - self.get_last_receipt_for_user_txn, - user_id, - room_id, - receipt_types, - ) - if not result: - return None - - event_id, _ = result - return event_id - - def get_last_receipt_for_user_txn( + def get_last_unthreaded_receipt_for_user_txn( self, txn: LoggingTransaction, user_id: str, @@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ - Fetch the event ID and stream_ordering for the latest receipt in a room - with one of the given receipt types. + Fetch the event ID and stream_ordering for the latest unthreaded receipt + in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. + receipt_types: The receipt types to fetch. Returns: The event ID and stream ordering of the latest receipt, if one exists. @@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore): WHERE {clause} AND user_id = ? AND room_id = ? + AND thread_id IS NULL ORDER BY stream_ordering DESC LIMIT 1 """ diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 9459ee1705..81253d0361 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, Optional from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase): ) ) + def get_last_unthreaded_receipt( + self, receipt_types: Collection[str], room_id: Optional[str] = None + ) -> Optional[str]: + """ + Fetch the event ID for the latest unthreaded receipt in the test room for the test user. + + Args: + receipt_types: The receipt types to fetch. + + Returns: + The latest receipt, if one exists. + """ + result = self.get_success( + self.store.db_pool.runInteraction( + "get_last_receipt_event_id_for_user", + self.store.get_last_unthreaded_receipt_for_user_txn, + OUR_USER_ID, + room_id or self.room_id1, + receipt_types, + ) + ) + if not result: + return None + + event_id, _ = result + return event_id + def test_return_empty_with_no_data(self) -> None: res = self.get_success( self.store.get_receipts_for_user( @@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase): ) self.assertEqual(res, {}) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id1, - [ - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) + self.assertEqual(res, None) def test_get_receipts_for_user(self) -> None: @@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase): ) # Test we get the latest event when we want both private and public receipts - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) self.assertEqual(res, event1_2_id) # Test we get the older event when we want only public receipt - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ]) self.assertEqual(res, event1_1_id) # Test we get the latest event when we want only the private receipt - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE]) self.assertEqual(res, event1_2_id) # Test receipt updating @@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase): self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] - ) - ) + res = self.get_last_unthreaded_receipt([ReceiptTypes.READ]) self.assertEqual(res, event1_2_id) # Send some events into the second room @@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase): {}, ) ) - res = self.get_success( - self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, - self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], - ) + res = self.get_last_unthreaded_receipt( + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2 ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From be76cd8200b18f3c68b895f85ac7ef5b0ddc2466 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 29 Sep 2022 14:23:24 +0100 Subject: Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556) --- changelog.d/13556.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 +- synapse/api/constants.py | 11 ++ synapse/api/errors.py | 16 ++ synapse/config/experimental.py | 19 +++ synapse/handlers/admin.py | 5 + synapse/handlers/auth.py | 11 ++ synapse/handlers/register.py | 8 + synapse/replication/http/register.py | 5 + synapse/rest/admin/users.py | 43 ++++- synapse/rest/client/login.py | 37 +++- synapse/rest/client/register.py | 22 ++- synapse/storage/databases/main/__init__.py | 9 +- synapse/storage/databases/main/registration.py | 150 +++++++++++++++-- .../main/delta/73/03users_approved_column.sql | 20 +++ tests/rest/admin/test_user.py | 186 ++++++++++++++++++++- tests/rest/client/test_auth.py | 33 +++- tests/rest/client/test_login.py | 41 +++++ tests/rest/client/test_register.py | 32 +++- tests/rest/client/utils.py | 12 +- tests/storage/test_registration.py | 102 ++++++++++- 21 files changed, 731 insertions(+), 34 deletions(-) create mode 100644 changelog.d/13556.feature create mode 100644 synapse/storage/schema/main/delta/73/03users_approved_column.sql (limited to 'tests/storage') diff --git a/changelog.d/13556.feature b/changelog.d/13556.feature new file mode 100644 index 0000000000..f9d63db6c0 --- /dev/null +++ b/changelog.d/13556.feature @@ -0,0 +1 @@ +Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 450ba462ba..5fa599e70e 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], - "users": ["shadow_banned"], + "users": ["shadow_banned", "approved"], "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c178ddf070..c031903b1a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -269,3 +269,14 @@ class PublicRoomsFilterFields: GENERIC_SEARCH_TERM: Final = "generic_search_term" ROOM_TYPES: Final = "room_types" + + +class ApprovalNoticeMedium: + """Identifier for the medium this server will use to serve notice of approval for a + specific user's registration. + + As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md + """ + + NONE = "org.matrix.msc3866.none" + EMAIL = "org.matrix.msc3866.email" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 1c6b53aa24..c606207569 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -106,6 +106,8 @@ class Codes(str, Enum): # Part of MSC3895. UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE" + USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL" + class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. @@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError): return cs_error(self.msg, self.errcode, **extra) +class NotApprovedError(SynapseError): + def __init__( + self, + msg: str, + approval_notice_medium: str, + ): + super().__init__( + code=403, + msg=msg, + errcode=Codes.USER_AWAITING_APPROVAL, + additional_fields={"approval_notice_medium": approval_notice_medium}, + ) + + def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": """Utility method for constructing an error response for client-server interactions. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 933779c23a..31834fb27d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -14,10 +14,25 @@ from typing import Any +import attr + from synapse.config._base import Config from synapse.types import JsonDict +@attr.s(auto_attribs=True, frozen=True, slots=True) +class MSC3866Config: + """Configuration for MSC3866 (mandating approval for new users)""" + + # Whether the base support for the approval process is enabled. This includes the + # ability for administrators to check and update the approval of users, even if no + # approval is currently required. + enabled: bool = False + # Whether to require that new users are approved by an admin before their account + # can be used. Note that this setting is ignored if 'enabled' is false. + require_approval_for_new_accounts: bool = False + + class ExperimentalConfig(Config): """Config section for enabling experimental features""" @@ -97,6 +112,10 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + # MSC3866: M_USER_AWAITING_APPROVAL error code + raw_msc3866_config = experimental.get("msc3866", {}) + self.msc3866 = MSC3866Config(**raw_msc3866_config) + # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index cf9f19608a..f2989cc4a2 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -32,6 +32,7 @@ class AdminHandler: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -75,6 +76,10 @@ class AdminHandler: "is_guest", } + if self._msc3866_enabled: + # Only include the approved flag if support for MSC3866 is enabled. + user_info_to_return.add("approved") + # Restrict returned keys to a known set. user_info_dict = { key: value diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index eacd631ee0..f5f0e0e7a7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1009,6 +1009,17 @@ class AuthHandler: return res[0] return None + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + return await self.store.is_user_approved(user_id) + async def _find_user_id_and_pwd_hash( self, user_id: str ) -> Optional[Tuple[str, str]]: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cfcadb34db..ca1c7a1866 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -220,6 +220,7 @@ class RegistrationHandler: by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, auth_provider_id: Optional[str] = None, + approved: bool = False, ) -> str: """Registers a new client on the server. @@ -246,6 +247,8 @@ class RegistrationHandler: user_agent_ips: Tuples of user-agents and IP addresses used during the registration process. auth_provider_id: The SSO IdP the user used, if any. + approved: True if the new user should be considered already + approved by an administrator. Returns: The registered user_id. Raises: @@ -307,6 +310,7 @@ class RegistrationHandler: user_type=user_type, address=address, shadow_banned=shadow_banned, + approved=approved, ) profile = await self.store.get_profileinfo(localpart) @@ -695,6 +699,7 @@ class RegistrationHandler: user_type: Optional[str] = None, address: Optional[str] = None, shadow_banned: bool = False, + approved: bool = False, ) -> None: """Register user in the datastore. @@ -713,6 +718,7 @@ class RegistrationHandler: api.constants.UserTypes, or None for a normal user. address: the IP address used to perform the registration. shadow_banned: Whether to shadow-ban the user + approved: Whether to mark the user as approved by an administrator """ if self.hs.config.worker.worker_app: await self._register_client( @@ -726,6 +732,7 @@ class RegistrationHandler: user_type=user_type, address=address, shadow_banned=shadow_banned, + approved=approved, ) else: await self.store.register_user( @@ -738,6 +745,7 @@ class RegistrationHandler: admin=admin, user_type=user_type, shadow_banned=shadow_banned, + approved=approved, ) # Only call the account validity module(s) on the main process, to avoid diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 6c8f8388fd..61abb529c8 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type: Optional[str], address: Optional[str], shadow_banned: bool, + approved: bool, ) -> JsonDict: """ Args: @@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint): or None for a normal user. address: the IP address used to perform the regitration. shadow_banned: Whether to shadow-ban the user + approved: Whether the user should be considered already approved by an + administrator. """ return { "password_hash": password_hash, @@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "user_type": user_type, "address": address, "shadow_banned": shadow_banned, + "approved": approved, } async def _handle_request( # type: ignore[override] @@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type=content["user_type"], address=content["address"], shadow_banned=content["shadow_banned"], + approved=content["approved"], ) return 200, {} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 1274773d7e..15ac2059aa 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet): self.store = hs.get_datastores().main self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet): guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) + # If support for MSC3866 is not enabled, apply no filtering based on the + # `approved` column. + if self._msc3866_enabled: + approved = parse_boolean(request, "approved", default=True) + else: + approved = True + order_by = parse_string( request, "order_by", @@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet): direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) users, total = await self.store.get_users_paginate( - start, limit, user_id, name, guests, deactivated, order_by, direction + start, + limit, + user_id, + name, + guests, + deactivated, + order_by, + direction, + approved, ) + + # If support for MSC3866 is not enabled, don't show the approval flag. + if not self._msc3866_enabled: + for user in users: + del user["approved"] + ret = {"users": users, "total": total} if (start + limit) < total: ret["next_token"] = str(start + len(users)) @@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet): self.deactivate_account_handler = hs.get_deactivate_account_handler() self.registration_handler = hs.get_registration_handler() self.pusher_pool = hs.get_pusherpool() + self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def on_GET( self, request: SynapseRequest, user_id: str @@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet): HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" ) + approved: Optional[bool] = None + if "approved" in body and self._msc3866_enabled: + approved = body["approved"] + if not isinstance(approved, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'approved' parameter is not of type boolean", + ) + # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: new_external_ids = [ @@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet): if "user_type" in body: await self.store.set_user_type(target_user, user_type) + if approved is not None: + await self.store.update_user_approval_status(target_user, approved) + user = await self.admin_handler.get_user(target_user) assert user is not None @@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet): if password is not None: password_hash = await self.auth_handler.hash(password) + new_user_approved = True + if self._msc3866_enabled and approved is not None: + new_user_approved = approved + user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, @@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet): default_display_name=displayname, user_type=user_type, by_admin=True, + approved=new_user_approved, ) if threepids is not None: @@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet): user_type=user_type, default_display_name=displayname, by_admin=True, + approved=True, ) result = await register._create_registration_details(user_id, body) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0437c87d8d..f554586ac3 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -28,7 +28,14 @@ from typing import ( from typing_extensions import TypedDict -from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError +from synapse.api.constants import ApprovalNoticeMedium +from synapse.api.errors import ( + Codes, + InvalidClientTokenError, + LoginError, + NotApprovedError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService @@ -55,11 +62,11 @@ logger = logging.getLogger(__name__) class LoginResponse(TypedDict, total=False): user_id: str - access_token: str + access_token: Optional[str] home_server: str expires_in_ms: Optional[int] refresh_token: Optional[str] - device_id: str + device_id: Optional[str] well_known: Optional[Dict[str, Any]] @@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet): hs.config.registration.refreshable_access_token_lifetime is not None ) + # Whether we need to check if the user has been approved or not. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet): except KeyError: raise SynapseError(400, "Missing JSON keys.") + if self._require_approval: + approved = await self.auth_handler.is_user_approved(result["user_id"]) + if not approved: + raise NotApprovedError( + msg="This account is pending approval by a server administrator.", + approval_notice_medium=ApprovalNoticeMedium.NONE, + ) + well_known_data = self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data @@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) + if self._require_approval: + approved = await self.auth_handler.is_user_approved(user_id) + if not approved: + # If the user isn't approved (and needs to be) we won't allow them to + # actually log in, so we don't want to create a device/access token. + return LoginResponse( + user_id=user_id, + home_server=self.hs.hostname, + ) + initial_display_name = login_submission.get("initial_device_display_name") ( device_id, diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 20bab20c8f..de810ae3ec 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -21,10 +21,15 @@ from twisted.web.server import Request import synapse import synapse.api.auth import synapse.types -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, + NotApprovedError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet): hs.config.registration.inhibit_user_in_use_error ) + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler ) @@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet): access_token=return_dict.get("access_token"), ) + if self._require_approval: + raise NotApprovedError( + msg="This account needs to be approved by an administrator before it can be used.", + approval_notice_medium=ApprovalNoticeMedium.NONE, + ) + return 200, return_dict async def _do_appservice_registration( @@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet): "user_id": user_id, "home_server": self.hs.hostname, } - if not params.get("inhibit_login", False): + # We don't want to log the user in if we're going to deny them access because + # they need to be approved first. + if not params.get("inhibit_login", False) and not self._require_approval: device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") ( diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0843f10340..a62b4abd4e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -203,6 +203,7 @@ class DataStore( deactivated: bool = False, order_by: str = UserSortOrder.USER_ID.value, direction: str = "f", + approved: bool = True, ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -217,6 +218,7 @@ class DataStore( deactivated: whether to include deactivated users order_by: the sort order of the returned list direction: sort ascending or descending + approved: whether to include approved users Returns: A tuple of a list of mappings from user to information and a count of total users. """ @@ -249,6 +251,11 @@ class DataStore( if not deactivated: filters.append("deactivated = 0") + if not approved: + # We ignore NULL values for the approved flag because these should only + # be already existing users that we consider as already approved. + filters.append("approved IS FALSE") + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" sql_base = f""" @@ -262,7 +269,7 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, - displayname, avatar_url, creation_ts * 1000 as creation_ts + displayname, avatar_url, creation_ts * 1000 as creation_ts, approved {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ac821878b0..2996d6bb4d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" - return await self.db_pool.simple_select_one( - table="users", - keyvalues={"name": user_id}, - retcols=[ - "name", - "password_hash", - "is_guest", - "admin", - "consent_version", - "consent_ts", - "consent_server_notice_sent", - "appservice_id", - "creation_ts", - "user_type", - "deactivated", - "shadow_banned", - ], - allow_none=True, + + def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + # We could technically use simple_select_one here, but it would not perform + # the COALESCEs (unless hacked into the column names), which could yield + # confusing results. + txn.execute( + """ + SELECT + name, password_hash, is_guest, admin, consent_version, consent_ts, + consent_server_notice_sent, appservice_id, creation_ts, user_type, + deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, + COALESCE(approved, TRUE) AS approved + FROM users + WHERE name = ? + """, + (user_id,), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + if len(rows) == 0: + return None + + return rows[0] + + row = await self.db_pool.runInteraction( desc="get_user_by_id", + func=get_user_by_id_txn, ) + if row is not None: + # If we're using SQLite our boolean values will be integers. Because we + # present some of this data as is to e.g. server admins via REST APIs, we + # want to make sure we're returning the right type of data. + # Note: when adding a column name to this list, be wary of NULLable columns, + # since NULL values will be turned into False. + boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"] + for column in boolean_columns: + if not isinstance(row[column], bool): + row[column] = bool(row[column]) + + return row + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get a UserInfo object for a user by user ID. @@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return res if res else False + @cached() + async def is_user_approved(self, user_id: str) -> bool: + """Checks if a user is approved and therefore can be allowed to log in. + + If the user's 'approved' column is NULL, we consider it as true given it means + the user was registered when support for an approval flow was either disabled + or nonexistent. + + Args: + user_id: the user to check the approval status of. + + Returns: + A boolean that is True if the user is approved, False otherwise. + """ + + def is_user_approved_txn(txn: LoggingTransaction) -> bool: + txn.execute( + """ + SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ? + """, + (user_id,), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + # We cast to bool because the value returned by the database engine might + # be an integer if we're using SQLite. + return bool(rows[0]["approved"]) + + return await self.db_pool.runInteraction( + desc="is_user_pending_approval", + func=is_user_approved_txn, + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + def update_user_approval_status_txn( + self, txn: LoggingTransaction, user_id: str, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean is turned into an int because the column is a smallint. + + Args: + txn: the current database transaction. + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"approved": approved}, + ) + + # Invalidate the caches of methods that read the value of the 'approved' flag. + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) + class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( @@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + # If support for MSC3866 is enabled and configured to require approval for new + # account, we will create new users with an 'approved' flag set to false. + self._require_approval = ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ) + async def add_access_token_to_user( self, user_id: str, @@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool = False, user_type: Optional[str] = None, shadow_banned: bool = False, + approved: bool = False, ) -> None: """Attempts to register an account. @@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): or None for a normal user. shadow_banned: Whether the user is shadow-banned, i.e. they may be told their requests succeeded but we ignore them. + approved: Whether to consider the user has already been approved by an + administrator. Raises: StoreError if the user_id could not be registered. @@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin, user_type, shadow_banned, + approved, ) def _register_user( @@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): admin: bool, user_type: Optional[str], shadow_banned: bool, + approved: bool, ) -> None: user_id_obj = UserID.from_string(user_id) now = int(self._clock.time()) + user_approved = approved or not self._require_approval + try: if was_guest: # Ensure that the guest user actually exists @@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "admin": 1 if admin else 0, "user_type": user_type, "shadow_banned": shadow_banned, + "approved": user_approved, }, ) else: @@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "admin": 1 if admin else 0, "user_type": user_type, "shadow_banned": shadow_banned, + "approved": user_approved, }, ) @@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) + async def update_user_approval_status( + self, user_id: UserID, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean will be turned into an int (in update_user_approval_status_txn) + because the column is a smallint. + + Args: + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + await self.db_pool.runInteraction( + "update_user_approval_status", + self.update_user_approval_status_txn, + user_id.to_string(), + approved, + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/03users_approved_column.sql b/synapse/storage/schema/main/delta/73/03users_approved_column.sql new file mode 100644 index 0000000000..5328d592ea --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03users_approved_column.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to the users table to track whether the user needs to be approved by an +-- administrator. +-- A NULL column means the user was created before this feature was supported by Synapse, +-- and should be considered as TRUE. +ALTER TABLE users ADD COLUMN approved BOOLEAN; diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 1847e6ad6b..4c1ce33463 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import UserTypes +from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client import devices, login, logout, profile, room, sync +from synapse.rest.client import devices, login, logout, profile, register, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.types import JsonDict, UserID @@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. @@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + # invalid approved + channel = self.make_request( + "GET", + self.url + "?approved=not_bool", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + # unkown order_by channel = self.make_request( "GET", @@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase): self._order_test([self.admin_user, user1, user2], "creation_ts", "f") self._order_test([user2, user1, self.admin_user], "creation_ts", "b") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_filter_out_approved(self) -> None: + """Tests that the endpoint can filter out approved users.""" + # Create our users. + self._create_users(2) + + # Get the list of users. + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + # Exclude the admin, because we don't want to accidentally un-approve the admin. + non_admin_user_ids = [ + user["name"] + for user in channel.json_body["users"] + if user["name"] != self.admin_user + ] + + self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids) + + # Select a user and un-approve them. We do this rather than the other way around + # because, since these users are created by an admin, we consider them already + # approved. + not_approved_user = non_admin_user_ids[0] + + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{not_approved_user}", + {"approved": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + # Now get the list of users again, this time filtering out approved users. + channel = self.make_request( + "GET", + self.url + "?approved=false", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + + non_admin_user_ids = [ + user["name"] + for user in channel.json_body["users"] + if user["name"] != self.admin_user + ] + + # We should only have our unapproved user now. + self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) + self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def _order_test( self, expected_user_list: List[str], @@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, sync.register_servlets, + register.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_approve_account(self) -> None: + """Tests that approving an account correctly sets the approved flag for the user.""" + url = self.url_prefix % "@bob:test" + + # Create the user using the client-server API since otherwise the user will be + # marked as approved automatically. + channel = self.make_request( + "POST", + "register", + { + "username": "bob", + "password": "test", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + # Get user + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(False, channel.json_body["approved"]) + + # Approve user + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content={"approved": True}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(True, channel.json_body["approved"]) + + # Check that the user is now approved + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIs(True, channel.json_body["approved"]) + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_register_approved(self) -> None: + url = self.url_prefix % "@bob:test" + + # Create user + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content={"password": "abc123", "approved": True}, + ) + + self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(1, channel.json_body["approved"]) + + # Get user + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(1, channel.json_body["approved"]) + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 05355c7fb6..090cef5216 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin -from synapse.api.constants import LoginType +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase): body={"auth": {"session": session_id}}, ) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config( + { + "oidc_config": TEST_OIDC_CONFIG, + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + }, + } + ) + def test_sso_not_approved(self) -> None: + """Tests that if we register a user via SSO while requiring approval for new + accounts, we still raise the correct error before logging the user in. + """ + login_resp = self.helper.login_via_oidc("username", expected_status=403) + + self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) + self.assertEqual( + ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"] + ) + + # Check that we didn't register a device for the user during the login attempt. + devices = self.get_success( + self.hs.get_datastores().main.get_devices_by_user("@username:test") + ) + + self.assertEqual(len(devices), 0) + class RefreshAuthTests(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e2a4d98275..e801ba8c8b 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet @@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + register.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + params = { + "type": LoginType.PASSWORD, + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request("POST", LOGIN_URL, params) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") class MultiSSOTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index b781875d52..11cf3939d8 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -22,7 +22,11 @@ import pkg_resources from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync @@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dd26145bf8..c249a42bb6 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -543,8 +543,12 @@ class RestHelper: return channel.json_body - def login_via_oidc(self, remote_user_id: str) -> JsonDict: - """Log in (as a new user) via OIDC + def login_via_oidc( + self, + remote_user_id: str, + expected_status: int = 200, + ) -> JsonDict: + """Log in via OIDC Returns the result of the final token login. @@ -578,7 +582,9 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == HTTPStatus.OK + assert ( + channel.code == expected_status + ), f"unexpected status in response: {channel.code}" return channel.json_body def auth_via_oidc( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 853a93afab..05ea802008 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError from synapse.server import HomeServer +from synapse.types import JsonDict, UserID from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RegistrationStoreTestCase(HomeserverTestCase): @@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): "user_type": None, "deactivated": 0, "shadow_banned": 0, + "approved": 1, }, (self.get_success(self.store.get_user_by_id(self.user_id))), ) @@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase): ThreepidValidationError, ) self.assertEqual(e.value.msg, "Validation token not found or has expired", e) + + +class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): + def default_config(self) -> JsonDict: + config = super().default_config() + + # If there's already some config for this feature in the default config, it + # means we're overriding it with @override_config. In this case we don't want + # to do anything more with it. + msc3866_config = config.get("experimental_features", {}).get("msc3866") + if msc3866_config is not None: + return config + + # Require approval for all new accounts. + config["experimental_features"] = { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.user_id = "@my-user:test" + self.pwhash = "{xx1}123456789" + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": False, + } + } + } + ) + def test_approval_not_required(self) -> None: + """Tests that if we don't require approval for new accounts, newly created + accounts are automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertTrue(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approval_required(self) -> None: + """Tests that if we require approval for new accounts, newly created accounts + are not automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertFalse(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + def test_override(self) -> None: + """Tests that if we require approval for new accounts, but we explicitly say the + new user should be considered approved, they're marked as approved. + """ + self.get_success( + self.store.register_user( + self.user_id, + self.pwhash, + approved=True, + ) + ) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["approved"], 1) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approve_user(self) -> None: + """Tests that approving the user updates their approval status.""" + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + self.get_success( + self.store.update_user_approval_status( + UserID.from_string(self.user_id), True + ) + ) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) -- cgit 1.5.1 From e8f30a76caa4394ebb3e77c56df951e3626b3fdd Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 11:54:53 +0100 Subject: Fix overflows in /messages backfill calculation (#13936) * Reproduce bug * Compute `least_function` first * Substitute `least_function` with an f-string * Bugfix: avoid overflow Co-authored-by: Eric Eastwood --- changelog.d/13936.feature | 1 + synapse/storage/databases/main/event_federation.py | 82 ++++++++++++++-------- tests/storage/test_event_federation.py | 61 ++++++++++++---- 3 files changed, 103 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13936.feature (limited to 'tests/storage') diff --git a/changelog.d/13936.feature b/changelog.d/13936.feature new file mode 100644 index 0000000000..d86bf7ed80 --- /dev/null +++ b/changelog.d/13936.feature @@ -0,0 +1 @@ +Exponentially backoff from backfilling the same event over and over. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 17f2fd4458..6b9a629edd 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -73,13 +73,30 @@ pdus_pruned_from_federation_queue = Counter( logger = logging.getLogger(__name__) -BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS: int = int( - datetime.timedelta(days=7).total_seconds() -) -BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS: int = int( - datetime.timedelta(hours=1).total_seconds() +# Parameters controlling exponential backoff between backfill failures. +# After the first failure to backfill, we wait 2 hours before trying again. If the +# second attempt fails, we wait 4 hours before trying again. If the third attempt fails, +# we wait 8 hours before trying again, ... and so on. +# +# Each successive backoff period is twice as long as the last. However we cap this +# period at a maximum of 2^8 = 256 hours: a little over 10 days. (This is the smallest +# power of 2 which yields a maximum backoff period of at least 7 days---which was the +# original maximum backoff period.) Even when we hit this cap, we will continue to +# make backfill attempts once every 10 days. +BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS = 8 +BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS = int( + datetime.timedelta(hours=1).total_seconds() * 1000 ) +# We need a cap on the power of 2 or else the backoff period +# 2^N * (milliseconds per hour) +# will overflow when calcuated within the database. We ensure overflow does not occur +# by checking that the largest backoff period fits in a 32-bit signed integer. +_LONGEST_BACKOFF_PERIOD_MILLISECONDS = ( + 2**BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS +) * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS +assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1) + # All the info we need while iterating the DAG while backfilling @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -767,7 +784,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # persisted in our database yet (meaning we don't know their depth # specifically). So we need to look for the approximate depth from # the events connected to the current backwards extremeties. - sql = """ + + if isinstance(self.database_engine, PostgresEngine): + least_function = "LEAST" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "MIN" + else: + raise RuntimeError("Unknown database engine") + + sql = f""" SELECT backward_extrem.event_id, event.depth FROM events AS event /** * Get the edge connections from the event_edges table @@ -825,7 +850,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas */ AND ( failed_backfill_attempt_info.event_id IS NULL - OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + ( + (1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */)) + * ? /* step */ + ) ) /** * Sort from highest (closest to the `current_depth`) to the lowest depth @@ -837,22 +865,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas LIMIT ? """ - if isinstance(self.database_engine, PostgresEngine): - least_function = "least" - elif isinstance(self.database_engine, Sqlite3Engine): - least_function = "min" - else: - raise RuntimeError("Unknown database engine") - txn.execute( - sql % (least_function,), + sql, ( room_id, False, current_depth, self._clock.time_msec(), - 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, - 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, limit, ), ) @@ -902,7 +923,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def get_insertion_event_backward_extremities_in_room_txn( txn: LoggingTransaction, room_id: str ) -> List[Tuple[str, int]]: - sql = """ + if isinstance(self.database_engine, PostgresEngine): + least_function = "LEAST" + elif isinstance(self.database_engine, Sqlite3Engine): + least_function = "MIN" + else: + raise RuntimeError("Unknown database engine") + + sql = f""" SELECT insertion_event_extremity.event_id, event.depth /* We only want insertion events that are also marked as backwards extremities */ @@ -942,7 +970,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas */ AND ( failed_backfill_attempt_info.event_id IS NULL - OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + /*least*/%s((1 << failed_backfill_attempt_info.num_attempts) * ? /* step */, ? /* upper bound */) + OR ? /* current_time */ >= failed_backfill_attempt_info.last_attempt_ts + ( + (1 << {least_function}(failed_backfill_attempt_info.num_attempts, ? /* max doubling steps */)) + * ? /* step */ + ) ) /** * Sort from highest (closest to the `current_depth`) to the lowest depth @@ -954,21 +985,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas LIMIT ? """ - if isinstance(self.database_engine, PostgresEngine): - least_function = "least" - elif isinstance(self.database_engine, Sqlite3Engine): - least_function = "min" - else: - raise RuntimeError("Unknown database engine") - txn.execute( - sql % (least_function,), + sql, ( room_id, current_depth, self._clock.time_msec(), - 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, - 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, limit, ), ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 398f338b66..59b8910907 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -766,9 +766,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_backfill_points_in_room(room_id, depth_map["B"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"] - ) + self.assertEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]) # Try at "A" backfill_points = self.get_success( @@ -814,7 +812,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. - self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"]) + self.assertEqual(backfill_event_ids, ["b6", "2", "b1"]) def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration( self, @@ -860,7 +858,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["b3", "b2"]) + self.assertEqual(backfill_event_ids, ["b3", "b2"]) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it @@ -871,7 +869,48 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["b3", "b2", "b1"]) + self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"]) + + def test_get_backfill_points_in_room_works_after_many_failed_pull_attempts_that_could_naively_overflow( + self, + ) -> None: + """ + A test that reproduces #13929 (Postgres only). + + Test to make sure we can still get backfill points after many failed pull + attempts that cause us to backoff to the limit. Even if the backoff formula + would tell us to wait for more seconds than can be expressed in a 32 bit + signed int. + """ + setup_info = self._setup_room_for_backfill_tests() + room_id = setup_info.room_id + depth_map = setup_info.depth_map + + # Pretend that we have tried and failed 10 times to backfill event b1. + for _ in range(10): + self.get_success( + self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause") + ) + + # If the backoff periods grow without limit: + # After the first failed attempt, we would have backed off for 1 << 1 = 2 hours. + # After the second failed attempt we would have backed off for 1 << 2 = 4 hours, + # so after the 10th failed attempt we should backoff for 1 << 10 == 1024 hours. + # Wait 1100 hours just so we have a nice round number. + self.reactor.advance(datetime.timedelta(hours=1100).total_seconds()) + + # 1024 hours in milliseconds is 1024 * 3600000, which exceeds the largest 32 bit + # signed integer. The bug we're reproducing is that this overflow causes an + # error in postgres preventing us from fetching a set of backwards extremities + # to retry fetching. + backfill_points = self.get_success( + self.store.get_backfill_points_in_room(room_id, depth_map["A"], limit=100) + ) + + # We should aim to fetch all backoff points: b1's latest backoff period has + # expired, and we haven't tried the rest. + backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] + self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"]) def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: """ @@ -965,9 +1004,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual( - backfill_event_ids, ["insertion_eventB", "insertion_eventA"] - ) + self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"]) # Try at "insertion_eventA" backfill_points = self.get_success( @@ -1011,7 +1048,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] # Only the backfill points that we didn't record earlier exist here. - self.assertListEqual(backfill_event_ids, ["insertion_eventB"]) + self.assertEqual(backfill_event_ids, ["insertion_eventB"]) def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( self, @@ -1069,7 +1106,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, []) + self.assertEqual(backfill_event_ids, []) # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and # see if we can now backfill it @@ -1083,7 +1120,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) ) backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) + self.assertEqual(backfill_event_ids, ["insertion_eventA"]) @attr.s -- cgit 1.5.1 From 6d543d6d9f56e39199b7e460d0081b02d61f12be Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 16:34:47 +0100 Subject: Update mypy and mypy-zope (#13925) * Update mypy and mypy-zope * Unignore assigning to LogRecord attributes Presumably https://github.com/python/typeshed/pull/8064 makes this ok Cherry-picked from #13521 * Remove unused ignores due to mypy ParamSpec fixes https://github.com/python/mypy/pull/12668 Cherry-picked from #13521 * Remove additional unused ignores * Fix new mypy complaints related to `assertGreater` Presumably due to https://github.com/python/typeshed/pull/8077 * Changelog * Reword changelog Co-authored-by: Patrick Cloke Co-authored-by: Patrick Cloke --- changelog.d/13925.misc | 1 + poetry.lock | 59 +++++++++++++++--------------- scripts-dev/check_pydantic_models.py | 5 +-- synapse/app/_base.py | 4 +- synapse/logging/context.py | 20 +++++----- synapse/logging/opentracing.py | 4 +- synapse/storage/database.py | 22 +++-------- synapse/storage/databases/main/search.py | 2 +- tests/storage/test_monthly_active_users.py | 6 +++ tests/utils.py | 4 +- 10 files changed, 60 insertions(+), 67 deletions(-) create mode 100644 changelog.d/13925.misc (limited to 'tests/storage') diff --git a/changelog.d/13925.misc b/changelog.d/13925.misc new file mode 100644 index 0000000000..f490ab122e --- /dev/null +++ b/changelog.d/13925.misc @@ -0,0 +1 @@ +Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/poetry.lock b/poetry.lock index 0f6d1cfa69..63ef8573a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -573,11 +573,11 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.950" +version = "0.981" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -600,14 +600,14 @@ python-versions = "*" [[package]] name = "mypy-zope" -version = "0.3.7" +version = "0.3.11" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" [package.dependencies] -mypy = "0.950" +mypy = "0.981" "zope.interface" = "*" "zope.schema" = "*" @@ -2162,37 +2162,38 @@ msgpack = [ {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, ] mypy = [ - {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, - {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, - {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, - {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, - {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, - {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, - {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, - {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, - {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, - {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, - {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, - {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, - {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, - {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, - {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, - {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, - {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, - {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, - {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, + {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, + {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, + {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, + {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, + {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, + {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, + {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, + {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, + {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, + {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, + {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, + {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, + {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, + {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, + {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, + {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, + {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, + {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] mypy-zope = [ - {file = "mypy-zope-0.3.7.tar.gz", hash = "sha256:9da171e78e8ef7ac8922c86af1a62f1b7f3244f121020bd94a2246bc3f33c605"}, - {file = "mypy_zope-0.3.7-py3-none-any.whl", hash = "sha256:9c7637d066e4d1bafa0651abc091c752009769098043b236446e6725be2bc9c2"}, + {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, + {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, ] netaddr = [ {file = "netaddr-0.8.0-py2.py3-none-any.whl", hash = "sha256:9666d0232c32d2656e5e5f8d735f58fd6c7457ce52fc21c98d45f2af78f990ac"}, diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index d0fb811bdb..9f2b7ded5b 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -88,10 +88,9 @@ def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 - if "strict" not in kwargs: # type: ignore[attr-defined] + if "strict" not in kwargs: raise MissingStrictInConstrainedTypeException(factory.__name__) - if not kwargs["strict"]: # type: ignore[index] + if not kwargs["strict"]: raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9a24bed0a0..000912e86e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -98,9 +98,7 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] + _sighup_callbacks.append((func, args, kwargs)) def start_worker_reactor( diff --git a/synapse/logging/context.py b/synapse/logging/context.py index fd9cb97920..6a08ffed64 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -586,7 +586,7 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - record.request = self._default_request # type: ignore + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for @@ -594,21 +594,21 @@ class LoggingContextFilter(logging.Filter): if context is not None: # Logging is interested in the request ID. Note that for backwards # compatibility this is stored as the "request" on the record. - record.request = str(context) # type: ignore + record.request = str(context) # Add some data from the HTTP request. request = context.request if request is None: return True - record.ip_address = request.ip_address # type: ignore - record.site_tag = request.site_tag # type: ignore - record.requester = request.requester # type: ignore - record.authenticated_entity = request.authenticated_entity # type: ignore - record.method = request.method # type: ignore - record.url = request.url # type: ignore - record.protocol = request.protocol # type: ignore - record.user_agent = request.user_agent # type: ignore + record.ip_address = request.ip_address + record.site_tag = request.site_tag + record.requester = request.requester + record.authenticated_entity = request.authenticated_entity + record.method = request.method + record.url = request.url + record.protocol = request.protocol + record.user_agent = request.user_agent return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ca2735dd6d..8ce5a2a338 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -992,9 +992,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # FIXME: We could update this to handle any type of function by ignoring the # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): # type: ignore[index] + for i, arg in enumerate(args[1:], start=1): set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) - set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bb28ded1b5..a252f8eaa0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -290,8 +290,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.after_callbacks.append((callback, args, kwargs)) def async_call_after( self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs @@ -312,8 +311,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.async_after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.async_after_callbacks.append((callback, args, kwargs)) def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs @@ -331,8 +329,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.exception_callbacks.append((callback, args, kwargs)) def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -421,10 +418,7 @@ class LoggingTransaction: sql = self.database_engine.convert_param_style(sql) if args: try: - # The type-ignore should be redundant once mypy releases a version with - # https://github.com/python/mypy/pull/12668. (`args` might be empty, - # (but we'll catch the index error if so.) - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index] + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) except Exception: # Don't let logging failures stop SQL from working pass @@ -655,9 +649,7 @@ class DatabasePool: # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see - # https://github.com/python/mypy/pull/12668 - for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] + for i, arg in enumerate(args): if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -665,9 +657,7 @@ class DatabasePool: i, func, ) - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see - # https://github.com/python/mypy/pull/12668 - for name, val in kwargs.items(): # type: ignore[attr-defined] + for name, val in kwargs.items(): if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f6e24b68d2..1b79acf955 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -641,7 +641,7 @@ class SearchStore(SearchBackgroundUpdateStore): raise Exception("Unrecognized database engine") # mypy expects to append only a `str`, not an `int` - args.append(limit) # type: ignore[arg-type] + args.append(limit) results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e8b4a5644b..3da8221109 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -96,8 +96,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. + # Check that timestamp really is an int. + assert timestamp is not None self.assertGreater(timestamp, 0) timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + assert timestamp is not None self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table @@ -166,9 +170,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(user_id2)) result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + assert result is not None self.assertGreater(result, 0) result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) + assert result is not None self.assertNotEqual(result, 0) @override_config({"max_mau_value": 5}) diff --git a/tests/utils.py b/tests/utils.py index 65db437697..045a8b5fa7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -270,9 +270,7 @@ class MockClock: *args: P.args, **kwargs: P.kwargs, ) -> None: - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: -- cgit 1.5.1 From 8e52cb0bce4c4e42a0f151f16e51529b7aba8f7d Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 16:37:48 +0100 Subject: Revert "Update mypy and mypy-zope (#13925)" This reverts commit 6d543d6d9f56e39199b7e460d0081b02d61f12be. --- changelog.d/13925.misc | 1 - poetry.lock | 59 +++++++++++++++--------------- scripts-dev/check_pydantic_models.py | 5 ++- synapse/app/_base.py | 4 +- synapse/logging/context.py | 20 +++++----- synapse/logging/opentracing.py | 4 +- synapse/storage/database.py | 22 ++++++++--- synapse/storage/databases/main/search.py | 2 +- tests/storage/test_monthly_active_users.py | 6 --- tests/utils.py | 4 +- 10 files changed, 67 insertions(+), 60 deletions(-) delete mode 100644 changelog.d/13925.misc (limited to 'tests/storage') diff --git a/changelog.d/13925.misc b/changelog.d/13925.misc deleted file mode 100644 index f490ab122e..0000000000 --- a/changelog.d/13925.misc +++ /dev/null @@ -1 +0,0 @@ -Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/poetry.lock b/poetry.lock index 63ef8573a0..0f6d1cfa69 100644 --- a/poetry.lock +++ b/poetry.lock @@ -573,11 +573,11 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.981" +version = "0.950" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.7" +python-versions = ">=3.6" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -600,14 +600,14 @@ python-versions = "*" [[package]] name = "mypy-zope" -version = "0.3.11" +version = "0.3.7" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" [package.dependencies] -mypy = "0.981" +mypy = "0.950" "zope.interface" = "*" "zope.schema" = "*" @@ -2162,38 +2162,37 @@ msgpack = [ {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, ] mypy = [ - {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, - {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, - {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, - {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, - {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, - {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, - {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, - {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, - {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, - {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, - {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, - {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, - {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, - {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, - {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, - {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, - {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, - {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, - {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, - {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, + {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, + {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, + {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, + {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, + {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, + {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, + {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, + {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, + {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, + {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, + {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, + {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, + {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, + {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, + {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, + {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, + {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, + {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, + {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, + {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, + {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, + {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, + {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] mypy-zope = [ - {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, - {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, + {file = "mypy-zope-0.3.7.tar.gz", hash = "sha256:9da171e78e8ef7ac8922c86af1a62f1b7f3244f121020bd94a2246bc3f33c605"}, + {file = "mypy_zope-0.3.7-py3-none-any.whl", hash = "sha256:9c7637d066e4d1bafa0651abc091c752009769098043b236446e6725be2bc9c2"}, ] netaddr = [ {file = "netaddr-0.8.0-py2.py3-none-any.whl", hash = "sha256:9666d0232c32d2656e5e5f8d735f58fd6c7457ce52fc21c98d45f2af78f990ac"}, diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 9f2b7ded5b..d0fb811bdb 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -88,9 +88,10 @@ def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - if "strict" not in kwargs: + # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 + if "strict" not in kwargs: # type: ignore[attr-defined] raise MissingStrictInConstrainedTypeException(factory.__name__) - if not kwargs["strict"]: + if not kwargs["strict"]: # type: ignore[index] raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 000912e86e..9a24bed0a0 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -98,7 +98,9 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - _sighup_callbacks.append((func, args, kwargs)) + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] def start_worker_reactor( diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 6a08ffed64..fd9cb97920 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -586,7 +586,7 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - record.request = self._default_request + record.request = self._default_request # type: ignore # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for @@ -594,21 +594,21 @@ class LoggingContextFilter(logging.Filter): if context is not None: # Logging is interested in the request ID. Note that for backwards # compatibility this is stored as the "request" on the record. - record.request = str(context) + record.request = str(context) # type: ignore # Add some data from the HTTP request. request = context.request if request is None: return True - record.ip_address = request.ip_address - record.site_tag = request.site_tag - record.requester = request.requester - record.authenticated_entity = request.authenticated_entity - record.method = request.method - record.url = request.url - record.protocol = request.protocol - record.user_agent = request.user_agent + record.ip_address = request.ip_address # type: ignore + record.site_tag = request.site_tag # type: ignore + record.requester = request.requester # type: ignore + record.authenticated_entity = request.authenticated_entity # type: ignore + record.method = request.method # type: ignore + record.url = request.url # type: ignore + record.protocol = request.protocol # type: ignore + record.user_agent = request.user_agent # type: ignore return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ce5a2a338..ca2735dd6d 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -992,9 +992,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # FIXME: We could update this to handle any type of function by ignoring the # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): + for i, arg in enumerate(args[1:], start=1): # type: ignore[index] set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) - set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a252f8eaa0..bb28ded1b5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -290,7 +290,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - self.after_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def async_call_after( self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs @@ -311,7 +312,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.async_after_callbacks is not None - self.async_after_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs @@ -329,7 +331,8 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - self.exception_callbacks.append((callback, args, kwargs)) + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -418,7 +421,10 @@ class LoggingTransaction: sql = self.database_engine.convert_param_style(sql) if args: try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + # The type-ignore should be redundant once mypy releases a version with + # https://github.com/python/mypy/pull/12668. (`args` might be empty, + # (but we'll catch the index error if so.) + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index] except Exception: # Don't let logging failures stop SQL from working pass @@ -649,7 +655,9 @@ class DatabasePool: # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - for i, arg in enumerate(args): + # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see + # https://github.com/python/mypy/pull/12668 + for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -657,7 +665,9 @@ class DatabasePool: i, func, ) - for name, val in kwargs.items(): + # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see + # https://github.com/python/mypy/pull/12668 + for name, val in kwargs.items(): # type: ignore[attr-defined] if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1b79acf955..f6e24b68d2 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -641,7 +641,7 @@ class SearchStore(SearchBackgroundUpdateStore): raise Exception("Unrecognized database engine") # mypy expects to append only a `str`, not an `int` - args.append(limit) + args.append(limit) # type: ignore[arg-type] results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 3da8221109..e8b4a5644b 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -96,12 +96,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) - # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. - # Check that timestamp really is an int. - assert timestamp is not None self.assertGreater(timestamp, 0) timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) - assert timestamp is not None self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table @@ -170,11 +166,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(user_id2)) result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) - assert result is not None self.assertGreater(result, 0) result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) - assert result is not None self.assertNotEqual(result, 0) @override_config({"max_mau_value": 5}) diff --git a/tests/utils.py b/tests/utils.py index 045a8b5fa7..65db437697 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -270,7 +270,9 @@ class MockClock: *args: P.args, **kwargs: P.kwargs, ) -> None: - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: -- cgit 1.5.1 From 285d72556bb3c36f075b336b2bdd6acb08391ad5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 30 Sep 2022 17:36:28 +0100 Subject: Update mypy and mypy-zope, attempt 3 (#13993) Co-authored-by: Patrick Cloke --- changelog.d/13925.misc | 1 + changelog.d/13993.misc | 1 + poetry.lock | 59 +++++++++++++++--------------- scripts-dev/check_pydantic_models.py | 5 +-- synapse/app/_base.py | 4 +- synapse/logging/context.py | 20 +++++----- synapse/logging/opentracing.py | 4 +- synapse/storage/database.py | 22 +++-------- synapse/storage/databases/main/search.py | 2 +- tests/storage/test_monthly_active_users.py | 7 +++- tests/utils.py | 4 +- 11 files changed, 61 insertions(+), 68 deletions(-) create mode 100644 changelog.d/13925.misc create mode 100644 changelog.d/13993.misc (limited to 'tests/storage') diff --git a/changelog.d/13925.misc b/changelog.d/13925.misc new file mode 100644 index 0000000000..f490ab122e --- /dev/null +++ b/changelog.d/13925.misc @@ -0,0 +1 @@ +Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/changelog.d/13993.misc b/changelog.d/13993.misc new file mode 100644 index 0000000000..f490ab122e --- /dev/null +++ b/changelog.d/13993.misc @@ -0,0 +1 @@ +Update mypy (0.950 -> 0.981) and mypy-zope (0.3.7 -> 0.3.11). diff --git a/poetry.lock b/poetry.lock index 0f6d1cfa69..63ef8573a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -573,11 +573,11 @@ python-versions = "*" [[package]] name = "mypy" -version = "0.950" +version = "0.981" description = "Optional static typing for Python" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] mypy-extensions = ">=0.4.3" @@ -600,14 +600,14 @@ python-versions = "*" [[package]] name = "mypy-zope" -version = "0.3.7" +version = "0.3.11" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" [package.dependencies] -mypy = "0.950" +mypy = "0.981" "zope.interface" = "*" "zope.schema" = "*" @@ -2162,37 +2162,38 @@ msgpack = [ {file = "msgpack-1.0.3.tar.gz", hash = "sha256:51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e"}, ] mypy = [ - {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, - {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, - {file = "mypy-0.950-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e7647df0f8fc947388e6251d728189cfadb3b1e558407f93254e35abc026e22"}, - {file = "mypy-0.950-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eaff8156016487c1af5ffa5304c3e3fd183edcb412f3e9c72db349faf3f6e0eb"}, - {file = "mypy-0.950-cp310-cp310-win_amd64.whl", hash = "sha256:563514c7dc504698fb66bb1cf897657a173a496406f1866afae73ab5b3cdb334"}, - {file = "mypy-0.950-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dd4d670eee9610bf61c25c940e9ade2d0ed05eb44227275cce88701fee014b1f"}, - {file = "mypy-0.950-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ca75ecf2783395ca3016a5e455cb322ba26b6d33b4b413fcdedfc632e67941dc"}, - {file = "mypy-0.950-cp36-cp36m-win_amd64.whl", hash = "sha256:6003de687c13196e8a1243a5e4bcce617d79b88f83ee6625437e335d89dfebe2"}, - {file = "mypy-0.950-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c653e4846f287051599ed8f4b3c044b80e540e88feec76b11044ddc5612ffed"}, - {file = "mypy-0.950-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e19736af56947addedce4674c0971e5dceef1b5ec7d667fe86bcd2b07f8f9075"}, - {file = "mypy-0.950-cp37-cp37m-win_amd64.whl", hash = "sha256:ef7beb2a3582eb7a9f37beaf38a28acfd801988cde688760aea9e6cc4832b10b"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0112752a6ff07230f9ec2f71b0d3d4e088a910fdce454fdb6553e83ed0eced7d"}, - {file = "mypy-0.950-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee0a36edd332ed2c5208565ae6e3a7afc0eabb53f5327e281f2ef03a6bc7687a"}, - {file = "mypy-0.950-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77423570c04aca807508a492037abbd72b12a1fb25a385847d191cd50b2c9605"}, - {file = "mypy-0.950-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ce6a09042b6da16d773d2110e44f169683d8cc8687e79ec6d1181a72cb028d2"}, - {file = "mypy-0.950-cp38-cp38-win_amd64.whl", hash = "sha256:5b231afd6a6e951381b9ef09a1223b1feabe13625388db48a8690f8daa9b71ff"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0384d9f3af49837baa92f559d3fa673e6d2652a16550a9ee07fc08c736f5e6f8"}, - {file = "mypy-0.950-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1fdeb0a0f64f2a874a4c1f5271f06e40e1e9779bf55f9567f149466fc7a55038"}, - {file = "mypy-0.950-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:61504b9a5ae166ba5ecfed9e93357fd51aa693d3d434b582a925338a2ff57fd2"}, - {file = "mypy-0.950-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a952b8bc0ae278fc6316e6384f67bb9a396eb30aced6ad034d3a76120ebcc519"}, - {file = "mypy-0.950-cp39-cp39-win_amd64.whl", hash = "sha256:eaea21d150fb26d7b4856766e7addcf929119dd19fc832b22e71d942835201ef"}, - {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, - {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, + {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, + {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, + {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, + {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, + {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, + {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, + {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, + {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, + {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, + {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, + {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, + {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, + {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, + {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, + {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, + {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, + {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, + {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, + {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, + {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, + {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, ] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] mypy-zope = [ - {file = "mypy-zope-0.3.7.tar.gz", hash = "sha256:9da171e78e8ef7ac8922c86af1a62f1b7f3244f121020bd94a2246bc3f33c605"}, - {file = "mypy_zope-0.3.7-py3-none-any.whl", hash = "sha256:9c7637d066e4d1bafa0651abc091c752009769098043b236446e6725be2bc9c2"}, + {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, + {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, ] netaddr = [ {file = "netaddr-0.8.0-py2.py3-none-any.whl", hash = "sha256:9666d0232c32d2656e5e5f8d735f58fd6c7457ce52fc21c98d45f2af78f990ac"}, diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index d0fb811bdb..9f2b7ded5b 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -88,10 +88,9 @@ def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: @functools.wraps(factory) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 - if "strict" not in kwargs: # type: ignore[attr-defined] + if "strict" not in kwargs: raise MissingStrictInConstrainedTypeException(factory.__name__) - if not kwargs["strict"]: # type: ignore[index] + if not kwargs["strict"]: raise MissingStrictInConstrainedTypeException(factory.__name__) return factory(*args, **kwargs) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9a24bed0a0..000912e86e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -98,9 +98,7 @@ def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - _sighup_callbacks.append((func, args, kwargs)) # type: ignore[arg-type] + _sighup_callbacks.append((func, args, kwargs)) def start_worker_reactor( diff --git a/synapse/logging/context.py b/synapse/logging/context.py index fd9cb97920..6a08ffed64 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -586,7 +586,7 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - record.request = self._default_request # type: ignore + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for @@ -594,21 +594,21 @@ class LoggingContextFilter(logging.Filter): if context is not None: # Logging is interested in the request ID. Note that for backwards # compatibility this is stored as the "request" on the record. - record.request = str(context) # type: ignore + record.request = str(context) # Add some data from the HTTP request. request = context.request if request is None: return True - record.ip_address = request.ip_address # type: ignore - record.site_tag = request.site_tag # type: ignore - record.requester = request.requester # type: ignore - record.authenticated_entity = request.authenticated_entity # type: ignore - record.method = request.method # type: ignore - record.url = request.url # type: ignore - record.protocol = request.protocol # type: ignore - record.user_agent = request.user_agent # type: ignore + record.ip_address = request.ip_address + record.site_tag = request.site_tag + record.requester = request.requester + record.authenticated_entity = request.authenticated_entity + record.method = request.method + record.url = request.url + record.protocol = request.protocol + record.user_agent = request.user_agent return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ca2735dd6d..8ce5a2a338 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -992,9 +992,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # FIXME: We could update this to handle any type of function by ignoring the # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): # type: ignore[index] + for i, arg in enumerate(args[1:], start=1): set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) - set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bb28ded1b5..a252f8eaa0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -290,8 +290,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.after_callbacks.append((callback, args, kwargs)) def async_call_after( self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs @@ -312,8 +311,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.async_after_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.async_after_callbacks.append((callback, args, kwargs)) def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs @@ -331,8 +329,7 @@ class LoggingTransaction: # LoggingTransaction isn't expecting there to be any callbacks; assert that # is not the case. assert self.exception_callbacks is not None - # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 - self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + self.exception_callbacks.append((callback, args, kwargs)) def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() @@ -421,10 +418,7 @@ class LoggingTransaction: sql = self.database_engine.convert_param_style(sql) if args: try: - # The type-ignore should be redundant once mypy releases a version with - # https://github.com/python/mypy/pull/12668. (`args` might be empty, - # (but we'll catch the index error if so.) - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index] + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) except Exception: # Don't let logging failures stop SQL from working pass @@ -655,9 +649,7 @@ class DatabasePool: # For now, we just log an error, and hope that it works on the first attempt. # TODO: raise an exception. - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see - # https://github.com/python/mypy/pull/12668 - for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated] + for i, arg in enumerate(args): if inspect.isgenerator(arg): logger.error( "Programming error: generator passed to new_transaction as " @@ -665,9 +657,7 @@ class DatabasePool: i, func, ) - # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see - # https://github.com/python/mypy/pull/12668 - for name, val in kwargs.items(): # type: ignore[attr-defined] + for name, val in kwargs.items(): if inspect.isgenerator(val): logger.error( "Programming error: generator passed to new_transaction as " diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f6e24b68d2..1b79acf955 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -641,7 +641,7 @@ class SearchStore(SearchBackgroundUpdateStore): raise Exception("Unrecognized database engine") # mypy expects to append only a `str`, not an `int` - args.append(limit) # type: ignore[arg-type] + args.append(limit) results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e8b4a5644b..c55c4db970 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -96,8 +96,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Test each of the registered users is marked as active timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater. + # Check that timestamp really is an int. + assert timestamp is not None self.assertGreater(timestamp, 0) timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + assert timestamp is not None self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table @@ -166,10 +170,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(user_id2)) result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + assert result is not None self.assertGreater(result, 0) result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) - self.assertNotEqual(result, 0) + self.assertIsNone(result) @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): diff --git a/tests/utils.py b/tests/utils.py index 65db437697..045a8b5fa7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -270,9 +270,7 @@ class MockClock: *args: P.args, **kwargs: P.kwargs, ) -> None: - # This type-ignore should be redundant once we use a mypy release with - # https://github.com/python/mypy/pull/12668. - self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: if timer.expired: -- cgit 1.5.1 From b4ec4f5e71a87d5bdc840a4220dfd9a34c54c847 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 09:47:04 -0400 Subject: Track notification counts per thread (implement MSC3773). (#13776) When retrieving counts of notifications segment the results based on the thread ID, but choose whether to return them as individual threads or as a single summed field by letting the client opt-in via a sync flag. The summarization code is also updated to be per thread, instead of per room. --- changelog.d/13776.feature | 1 + synapse/api/constants.py | 3 + synapse/api/filtering.py | 10 ++ synapse/config/experimental.py | 2 + synapse/handlers/sync.py | 40 ++++- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_tools.py | 9 +- synapse/rest/client/sync.py | 4 + synapse/rest/client/versions.py | 3 +- synapse/storage/database.py | 2 +- .../storage/databases/main/event_push_actions.py | 188 +++++++++++++-------- synapse/storage/schema/__init__.py | 6 +- .../delta/73/06thread_notifications_backfill.sql | 29 ++++ .../07thread_notifications_not_null.sql.postgres | 19 +++ .../73/07thread_notifications_not_null.sql.sqlite | 101 +++++++++++ tests/replication/slave/storage/test_events.py | 17 +- tests/storage/test_event_push_actions.py | 169 +++++++++++++++++- 17 files changed, 514 insertions(+), 93 deletions(-) create mode 100644 changelog.d/13776.feature create mode 100644 synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (limited to 'tests/storage') diff --git a/changelog.d/13776.feature b/changelog.d/13776.feature new file mode 100644 index 0000000000..22bce125ce --- /dev/null +++ b/changelog.d/13776.feature @@ -0,0 +1 @@ +Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c031903b1a..44c5ffc6a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -31,6 +31,9 @@ MAX_ALIAS_LENGTH = 255 # the maximum length for a user id is 255 characters MAX_USERID_LENGTH = 255 +# Constant value used for the pseudo-thread which is the main timeline. +MAIN_TIMELINE: Final = "main" + class Membership: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f7f46f8d80..c6e44dcf82 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"}, # Include or exclude events with the provided labels. # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, @@ -240,6 +241,9 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members + def unread_thread_notifications(self) -> bool: + return self._room_timeline_filter.unread_thread_notifications + async def filter_presence( self, events: Iterable[UserPresenceState] ) -> List[UserPresenceState]: @@ -304,6 +308,12 @@ class Filter: self.include_redundant_members = filter_json.get( "include_redundant_members", False ) + if hs.config.experimental.msc3773_enabled: + self.unread_thread_notifications: bool = filter_json.get( + "org.matrix.msc3773.unread_thread_notifications", False + ) + else: + self.unread_thread_notifications = False self.types = filter_json.get("types", None) self.not_types = filter_json.get("not_types", []) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 83695f24d9..6503ce6e34 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -99,6 +99,8 @@ class ExperimentalConfig(Config): self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) + # MSC3773: Thread notifications + self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) # MSC3715: dir param on /relations. self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4abb9b6127..329e89c604 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -40,7 +40,7 @@ from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -128,6 +128,7 @@ class JoinedSyncResult: ephemeral: List[JsonDict] account_data: List[JsonDict] unread_notifications: JsonDict + unread_thread_notifications: JsonDict summary: Optional[JsonDict] unread_count: int @@ -278,6 +279,8 @@ class SyncHandler: self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + self._msc3773_enabled = hs.config.experimental.msc3773_enabled + async def wait_for_sync_for_user( self, requester: Requester, @@ -1288,7 +1291,7 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> NotifCounts: + ) -> RoomNotifCounts: with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( @@ -2353,6 +2356,7 @@ class SyncHandler: ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + unread_thread_notifications={}, summary=summary, unread_count=0, ) @@ -2360,10 +2364,36 @@ class SyncHandler: if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - unread_notifications["notification_count"] = notifs.notify_count - unread_notifications["highlight_count"] = notifs.highlight_count + # Notifications for the main timeline. + notify_count = notifs.main_timeline.notify_count + highlight_count = notifs.main_timeline.highlight_count + unread_count = notifs.main_timeline.unread_count - room_sync.unread_count = notifs.unread_count + # Check the sync configuration. + if ( + self._msc3773_enabled + and sync_config.filter_collection.unread_thread_notifications() + ): + # And add info for each thread. + room_sync.unread_thread_notifications = { + thread_id: { + "notification_count": thread_notifs.notify_count, + "highlight_count": thread_notifs.highlight_count, + } + for thread_id, thread_notifs in notifs.threads.items() + if thread_id is not None + } + + else: + # Combine the unread counts for all threads and main timeline. + for thread_notifs in notifs.threads.values(): + notify_count += thread_notifs.notify_count + highlight_count += thread_notifs.highlight_count + unread_count += thread_notifs.unread_count + + unread_notifications["notification_count"] = notify_count + unread_notifications["highlight_count"] = highlight_count + room_sync.unread_count = unread_count sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..61d952742d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -31,7 +31,7 @@ from typing import ( from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -280,7 +280,7 @@ class BulkPushRuleEvaluator: # If the event does not have a relation, then cannot have any mutual # relations or thread ID. relations = {} - thread_id = "main" + thread_id = MAIN_TIMELINE if relation: relations = await self._get_mutual_relations( relation.parent_id, diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 658bf373b7..edeba27a45 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,7 +39,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - await concurrently_execute(get_room_unread_count, joins, 10) for notifs in room_notifs: - if notifs.notify_count == 0: + # Combine the counts from all the threads. + notify_count = notifs.main_timeline.notify_count + sum( + n.notify_count for n in notifs.threads.values() + ) + + if notify_count == 0: continue if group_by_room: @@ -47,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge += 1 else: # increment the badge count by the number of unread messages in the room - badge += notifs.notify_count + badge += notify_count return badge diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index c2989765ce..f1c23d68e5 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -509,6 +509,10 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + if room.unread_thread_notifications: + result[ + "org.matrix.msc3773.unread_thread_notifications" + ] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c95b0d6f19..280d306483 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,8 +103,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above - # Support for thread read receipts. + # Support for thread read receipts & notification counts. "org.matrix.msc3771": self.config.experimental.msc3771_enabled, + "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b4469eb964..7bb21f8f81 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", - "event_push_summary": "event_push_summary_unique_index", + "event_push_summary": "event_push_summary_unique_index2", "receipts_linearized": "receipts_linearized_unique_index", "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index cdc9ee5a37..3210e9cca1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -88,7 +88,7 @@ from typing import ( import attr -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction): @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ - The per-user, per-room count of notifications. Used by sync and push. + The per-user, per-room, per-thread count of notifications. Used by sync and push. """ notify_count: int = 0 @@ -165,6 +165,21 @@ class NotifCounts: highlight_count: int = 0 +@attr.s(slots=True, auto_attribs=True) +class RoomNotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + main_timeline: NotifCounts + # Map of thread ID to the notification counts. + threads: Dict[str, NotifCounts] + + def __len__(self) -> int: + # To properly account for the amount of space in any caches. + return len(self.threads) + 1 + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: @@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result - @cached(tree=True, max_entries=5000) + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after their latest read receipt. @@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to retrieve the counts for. Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", @@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: # Get the stream ordering of the user's latest receipt in the room. result = self.get_last_unthreaded_receipt_for_user_txn( txn, @@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas room_id: str, user_id: str, receipt_stream_ordering: int, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt in the room. If there are no receipts, the stream ordering of the user's join event. - Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + Returns: + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ - counts = NotifCounts() + main_counts = NotifCounts() + thread_counts: Dict[str, NotifCounts] = {} + + def _get_thread(thread_id: str) -> NotifCounts: + if thread_id == MAIN_TIMELINE: + return main_counts + return thread_counts.setdefault(thread_id, NotifCounts()) # First we pull the counts from the summary table. # @@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # receipt). txn.execute( """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary WHERE room_id = ? AND user_id = ? AND ( (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) OR last_receipt_stream_ordering = ? - ) + ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), ) - row = txn.fetchone() - - summary_stream_ordering = 0 - if row: - summary_stream_ordering = row[0] - counts.notify_count += row[1] - counts.unread_count += row[2] + max_summary_stream_ordering = 0 + for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count + + # Summaries will only be used if they have not been invalidated by + # a recent receipt; track the latest stream ordering or a valid summary. + # + # Note that since there's only one read receipt in the room per user, + # valid summaries are contiguous. + max_summary_stream_ordering = max( + summary_stream_ordering, max_summary_stream_ordering + ) # Next we need to count highlights, which aren't summarised sql = """ - SELECT COUNT(*) FROM event_push_actions + SELECT COUNT(*), thread_id FROM event_push_actions WHERE user_id = ? AND room_id = ? AND stream_ordering > ? AND highlight = 1 + GROUP BY thread_id """ txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) - row = txn.fetchone() - if row: - counts.highlight_count += row[0] + for highlight_count, thread_id in txn: + _get_thread(thread_id).highlight_count += highlight_count # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. start_unread_stream_ordering = max( - receipt_stream_ordering, summary_stream_ordering + receipt_stream_ordering, max_summary_stream_ordering ) - notify_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, start_unread_stream_ordering ) - counts.notify_count += notify_count - counts.unread_count += unread_count + for notif_count, unread_count, thread_id in unread_counts: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count - return counts + return RoomNotifCounts(main_counts, thread_counts) def _get_notif_unread_count_for_user_room( self, @@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, - ) -> Tuple[int, int]: + ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas If this is not given, then no maximum is applied. Return: - A tuple of the notif count and unread count in the given range. + A tuple of the notif count and unread count in the given range for + each thread. """ # If there have been no events in the room since the stream ordering, # there can't be any push actions either. if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): - return 0, 0 + return [] clause = "" args = [user_id, room_id, stream_ordering] @@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # If the max stream ordering is less than the min stream ordering, # then obviously there are zero push actions in that range. if max_stream_ordering <= stream_ordering: - return 0, 0 + return [] sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), - COUNT(CASE WHEN unread = 1 THEN 1 END) - FROM event_push_actions ea - WHERE user_id = ? + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions ea + WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? {clause} + GROUP BY thread_id """ txn.execute(sql, args) - row = txn.fetchone() - - if row: - return cast(Tuple[int, int], row) - - return 0, 0 + return cast(List[Tuple[int, int, str]], txn.fetchall()) async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. - notif_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) - # Replace the previous summary with the new counts. - # - # TODO(threads): Upsert per-thread instead of setting them all to main. - self.db_pool.simple_upsert_txn( + # First mark the summary for all threads in the room as cleared. + self.db_pool.simple_update_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id}, - values={ - "notif_count": notif_count, - "unread_count": unread_count, + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, - "thread_id": "main", }, ) + # Then any updated threads get their notification count and unread + # count updated. + self.db_pool.simple_update_many_txn( + txn, + table="event_push_summary", + key_names=("room_id", "user_id", "thread_id"), + key_values=[(room_id, user_id, row[2]) for row in unread_counts], + value_names=("notif_count", "unread_count"), + value_values=[(row[0], row[1]) for row in unread_counts], + ) + # We always update `event_push_summary_last_receipt_stream_id` to # ensure that we don't rescan the same receipts for remote users. @@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Calculate the new counts that should be upserted into event_push_summary sql = """ - SELECT user_id, room_id, + SELECT user_id, room_id, thread_id, coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, thread_id, count(*) as cnt, max(ea.stream_ordering) as stream_ordering FROM event_push_actions AS ea - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? AND ( old.last_receipt_stream_ordering IS NULL OR old.last_receipt_stream_ordering < ea.stream_ordering ) AND %s = 1 - GROUP BY user_id, room_id + GROUP BY user_id, room_id, thread_id ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) """ # First get the count of unread messages. @@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries: Dict[Tuple[str, str], _EventPushSummary] = {} + summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {} for row in txn: - summaries[(row[0], row[1])] = _EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], + summaries[(row[0], row[1], row[2])] = _EventPushSummary( + unread_count=row[3], + stream_ordering=row[4], notif_count=0, ) @@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) for row in txn: - if (row[0], row[1]) in summaries: - summaries[(row[0], row[1])].notif_count = row[2] + if (row[0], row[1], row[2]) in summaries: + summaries[(row[0], row[1], row[2])].notif_count = row[3] else: # Because the rules on notifying are different than the rules on marking # a message unread, we might end up with messages that notify but aren't # marked unread, so we might not have a summary for this (user, room) # tuple to complete. - summaries[(row[0], row[1])] = _EventPushSummary( + summaries[(row[0], row[1], row[2])] = _EventPushSummary( unread_count=0, - stream_ordering=row[3], - notif_count=row[2], + stream_ordering=row[4], + notif_count=row[3], ) logger.info("Rotating notifications, handling %d rows", len(summaries)) - # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - key_names=("user_id", "room_id"), - key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), + key_names=("user_id", "room_id", "thread_id"), + key_values=[ + (user_id, room_id, thread_id) + for user_id, room_id, thread_id in summaries + ], + value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, - "main", ) for summary in summaries.values() ], @@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) async def _remove_old_push_actions_that_have_rotated(self) -> None: - """Clear out old push actions that have been summarised.""" + """ + Clear out old push actions that have been summarised (and are older than + 1 day ago). + """ # We want to clear out anything that is older than a day that *has* already # been rotated. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 4a5c947699..19dbf2da7f 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -90,9 +90,9 @@ Changes in SCHEMA_VERSION = 73; SCHEMA_COMPAT_VERSION = ( - # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72 - # could break. - 72 + # The threads_id column must exist for event_push_actions, event_push_summary, + # receipts_linearized, and receipts_graph. + 73 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql new file mode 100644 index 0000000000..0ffde9bbeb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Forces the background updates from 06thread_notifications.sql to run in the +-- foreground as code will now require those to be "done". + +DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; + +-- Overwrite any null thread_id columns. +UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL; + +-- Do not run the event_push_summary_unique_index job if it is pending; the +-- thread_id field will be made required. +DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; +DROP INDEX IF EXISTS event_push_summary_unique_index; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres new file mode 100644 index 0000000000..33674f8c62 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The columns can now be made non-nullable. +ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite new file mode 100644 index 0000000000..5322ad77a4 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite @@ -0,0 +1,101 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- SQLite doesn't support modifying columns to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE event_push_actions_staging_new ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL, + unread SMALLINT, + thread_id TEXT NOT NULL, + inserted_ts BIGINT +); + +CREATE TABLE event_push_actions_new ( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + topological_ordering BIGINT, + stream_ordering BIGINT, + notif SMALLINT, + highlight SMALLINT, + unread SMALLINT, + thread_id TEXT NOT NULL, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) +); + +CREATE TABLE event_push_summary_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + unread_count BIGINT, + last_receipt_stream_ordering BIGINT, + thread_id TEXT NOT NULL +); + +-- Swap the indexes. +DROP INDEX IF EXISTS event_push_actions_staging_id; +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id); + +DROP INDEX IF EXISTS event_push_actions_room_id_user_id; +DROP INDEX IF EXISTS event_push_actions_rm_tokens; +DROP INDEX IF EXISTS event_push_actions_stream_ordering; +DROP INDEX IF EXISTS event_push_actions_u_highlight; +DROP INDEX IF EXISTS event_push_actions_highlights_index; +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering ); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id ); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering); +CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering); + +-- Copy the data. +INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) + SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts + FROM event_push_actions_staging; + +INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) + SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id + FROM event_push_actions; + +INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) + SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id + FROM event_push_summary; + +-- Drop the old tables. +DROP TABLE event_push_actions_staging; +DROP TABLE event_push_actions; +DROP TABLE event_push_summary; + +-- Rename the tables. +ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; +ALTER TABLE event_push_actions_new RENAME TO event_push_actions; +ALTER TABLE event_push_summary_new RENAME TO event_push_summary; + +-- Re-run background updates from 72/02event_push_actions_index.sql and +-- 72/06thread_notifications.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_summary_unique_index2', '{}') + ON CONFLICT (update_name) DO NOTHING; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_actions_stream_highlight_index', '{}') + ON CONFLICT (update_name) DO NOTHING; diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index efd92793c0..d42e36cdf1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -22,7 +22,10 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import ( + NotifCounts, + RoomNotifCounts, +) from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -178,7 +181,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} + ), ) self.persist( @@ -191,7 +196,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} + ), ) self.persist( @@ -206,7 +213,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), + RoomNotifCounts( + NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} + ), ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 473c965e19..89f986ac34 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -20,6 +20,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.types import JsonDict from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -133,13 +134,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) self.assertEqual( - counts, + counts.main_timeline, NotifCounts( notify_count=noitf_count, unread_count=0, highlight_count=highlight_count, ), ) + self.assertEqual(counts.threads, {}) def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( @@ -186,6 +188,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(0, 0) _create_event() + _assert_counts(1, 0) _rotate() _assert_counts(1, 0) @@ -236,6 +239,168 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0) + def test_count_aggregation_threads(self) -> None: + """ + This is essentially the same test as test_count_aggregation, but adds + events to the main timeline and to a thread. + """ + + user_id, token, _, other_token, room_id = self._create_users_and_room() + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From a7ba457b2b967ca098792d742bc304604b1824b7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 10:46:42 -0400 Subject: Mark events as read using threaded read receipts from MSC3771. (#13877) Applies the proper logic for unthreaded and threaded receipts to either apply to all events in the room or only events in the same thread, respectively. --- changelog.d/13877.feature | 1 + .../storage/databases/main/event_push_actions.py | 277 ++++++++++++++++----- .../73/08thread_receipts_non_null.sql.postgres | 23 ++ .../delta/73/08thread_receipts_non_null.sql.sqlite | 76 ++++++ tests/storage/test_event_push_actions.py | 189 +++++++++++++- 5 files changed, 504 insertions(+), 62 deletions(-) create mode 100644 changelog.d/13877.feature create mode 100644 synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite (limited to 'tests/storage') diff --git a/changelog.d/13877.feature b/changelog.d/13877.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13877.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3210e9cca1..7469cd336c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -421,7 +421,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - receipt_stream_ordering: int, + unthreaded_receipt_stream_ordering: int, ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -430,9 +430,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: The database transaction. room_id: The room ID to get unread counts for. user_id: The user ID to get unread counts for. - receipt_stream_ordering: The stream ordering of the user's latest - receipt in the room. If there are no receipts, the stream ordering - of the user's join event. + unthreaded_receipt_stream_ordering: The stream ordering of the user's latest + unthreaded receipt in the room. If there are no unthreaded receipts, + the stream ordering of the user's join event. Returns: A RoomNotifCounts object containing the notification count, the @@ -448,71 +448,181 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return main_counts return thread_counts.setdefault(thread_id, NotifCounts()) + receipt_types_clause, receipts_args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + # First we pull the counts from the summary table. # - # We check that `last_receipt_stream_ordering` matches the stream - # ordering given. If it doesn't match then a new read receipt has arrived and - # we haven't yet updated the counts in `event_push_summary` to reflect - # that; in that case we simply ignore `event_push_summary` counts - # and do a manual count of all of the rows in the `event_push_actions` table - # for this user/room. + # We check that `last_receipt_stream_ordering` matches the stream ordering of the + # latest receipt for the thread (which may be either the unthreaded read receipt + # or the threaded read receipt). # - # If `last_receipt_stream_ordering` is null then that means it's up to - # date (as the row was written by an older version of Synapse that + # If it doesn't match then a new read receipt has arrived and we haven't yet + # updated the counts in `event_push_summary` to reflect that; in that case we + # simply ignore `event_push_summary` counts. + # + # We then do a manual count of all the rows in the `event_push_actions` table + # for any user/room/thread which did not have a valid summary found. + # + # If `last_receipt_stream_ordering` is null then that means it's up-to-date + # (as the row was written by an older version of Synapse that # updated `event_push_summary` synchronously when persisting a new read # receipt). txn.execute( - """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id + f""" + SELECT notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE room_id = ? AND user_id = ? AND ( - (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) - OR last_receipt_stream_ordering = ? + (last_receipt_stream_ordering IS NULL AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)) + OR last_receipt_stream_ordering = COALESCE(threaded_receipt_stream_ordering, ?) ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, - (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + room_id, + user_id, + unthreaded_receipt_stream_ordering, + unthreaded_receipt_stream_ordering, + ), ) - max_summary_stream_ordering = 0 - for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + summarised_threads = set() + for notif_count, unread_count, thread_id in txn: + summarised_threads.add(thread_id) counts = _get_thread(thread_id) counts.notify_count += notif_count counts.unread_count += unread_count - # Summaries will only be used if they have not been invalidated by - # a recent receipt; track the latest stream ordering or a valid summary. - # - # Note that since there's only one read receipt in the room per user, - # valid summaries are contiguous. - max_summary_stream_ordering = max( - summary_stream_ordering, max_summary_stream_ordering - ) - # Next we need to count highlights, which aren't summarised - sql = """ + sql = f""" SELECT COUNT(*), thread_id FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE user_id = ? AND room_id = ? - AND stream_ordering > ? + AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?) AND highlight = 1 GROUP BY thread_id """ - txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) + txn.execute( + sql, + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + user_id, + room_id, + unthreaded_receipt_stream_ordering, + ), + ) for highlight_count, thread_id in txn: _get_thread(thread_id).highlight_count += highlight_count + # For threads which were summarised we need to count actions since the last + # rotation. + thread_id_clause, thread_id_args = make_in_list_sql_clause( + self.database_engine, "thread_id", summarised_threads + ) + + # The (inclusive) event stream ordering that was previously summarised. + rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + unread_counts = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, rotated_upto_stream_ordering + ) + for notif_count, unread_count, thread_id in unread_counts: + if thread_id not in summarised_threads: + continue + + if thread_id == MAIN_TIMELINE: + counts.notify_count += notif_count + counts.unread_count += unread_count + elif thread_id in thread_counts: + thread_counts[thread_id].notify_count += notif_count + thread_counts[thread_id].unread_count += unread_count + else: + # Previous thread summaries of 0 are discarded above. + # + # TODO If empty summaries are deleted this can be removed. + thread_counts[thread_id] = NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=0, + ) + # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. - start_unread_stream_ordering = max( - receipt_stream_ordering, max_summary_stream_ordering - ) - unread_counts = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, start_unread_stream_ordering + sql = f""" + SELECT + COUNT(CASE WHEN notif = 1 THEN 1 END), + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) + WHERE user_id = ? + AND room_id = ? + AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?) + AND NOT {thread_id_clause} + GROUP BY thread_id + """ + txn.execute( + sql, + ( + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *receipts_args, + user_id, + room_id, + unthreaded_receipt_stream_ordering, + *thread_id_args, + ), ) - - for notif_count, unread_count, thread_id in unread_counts: + for notif_count, unread_count, thread_id in txn: counts = _get_thread(thread_id) counts.notify_count += notif_count counts.unread_count += unread_count @@ -526,6 +636,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, + thread_id: Optional[str] = None, ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -540,6 +651,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas stream_ordering: The (exclusive) minimum stream ordering to consider. max_stream_ordering: The (inclusive) maximum stream ordering to consider. If this is not given, then no maximum is applied. + thread_id: The thread ID to fetch unread counts for. If this is not provided + then the results for *all* threads is returned. + + Note that if this is provided the resulting list will only have 0 or + 1 tuples in it. Return: A tuple of the notif count and unread count in the given range for @@ -551,10 +667,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): return [] - clause = "" + stream_ordering_clause = "" args = [user_id, room_id, stream_ordering] if max_stream_ordering is not None: - clause = "AND ea.stream_ordering <= ?" + stream_ordering_clause = "AND ea.stream_ordering <= ?" args.append(max_stream_ordering) # If the max stream ordering is less than the min stream ordering, @@ -562,6 +678,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas if max_stream_ordering <= stream_ordering: return [] + # Either limit the results to a specific thread or fetch all threads. + thread_id_clause = "" + if thread_id is not None: + thread_id_clause = "AND thread_id = ?" + args.append(thread_id) + sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), @@ -571,7 +693,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? - {clause} + {stream_ordering_clause} + {thread_id_clause} GROUP BY thread_id """ @@ -1086,7 +1209,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = """ - SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering + SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering FROM receipts_linearized AS r INNER JOIN events AS e USING (event_id) WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? @@ -1107,45 +1230,69 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas limit, ), ) - rows = cast(List[Tuple[int, str, str, int]], txn.fetchall()) + rows = cast(List[Tuple[int, str, str, Optional[str], int]], txn.fetchall()) # For each new read receipt we delete push actions from before it and # recalculate the summary. - for _, room_id, user_id, stream_ordering in rows: + # + # Care must be taken of whether it is a threaded or unthreaded receipt. + for _, room_id, user_id, thread_id, stream_ordering in rows: # Only handle our own read receipts. if not self.hs.is_mine_id(user_id): continue + thread_clause = "" + thread_args: Tuple = () + if thread_id is not None: + thread_clause = "AND thread_id = ?" + thread_args = (thread_id,) + + # For each new read receipt we delete push actions from before it and + # recalculate the summary. txn.execute( - """ + f""" DELETE FROM event_push_actions WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? AND highlight = 0 + {thread_clause} """, - (room_id, user_id, stream_ordering), + (room_id, user_id, stream_ordering, *thread_args), ) # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. unread_counts = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering - ) - - # First mark the summary for all threads in the room as cleared. - self.db_pool.simple_update_txn( txn, - table="event_push_summary", - keyvalues={"user_id": user_id, "room_id": room_id}, - updatevalues={ - "notif_count": 0, - "unread_count": 0, - "stream_ordering": old_rotate_stream_ordering, - "last_receipt_stream_ordering": stream_ordering, - }, + room_id, + user_id, + stream_ordering, + old_rotate_stream_ordering, + thread_id, ) + # For an unthreaded receipt, mark the summary for all threads in the room + # as cleared. + if thread_id is None: + self.db_pool.simple_update_txn( + txn, + table="event_push_summary", + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, + "stream_ordering": old_rotate_stream_ordering, + "last_receipt_stream_ordering": stream_ordering, + }, + ) + + # For a threaded receipt, we *always* want to update that receipt, + # event if there are no new notifications in that thread. This ensures + # the stream_ordering & last_receipt_stream_ordering are updated. + elif not unread_counts: + unread_counts = [(0, 0, thread_id)] + # Then any updated threads get their notification count and unread # count updated. self.db_pool.simple_update_many_txn( @@ -1153,8 +1300,16 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas table="event_push_summary", key_names=("room_id", "user_id", "thread_id"), key_values=[(room_id, user_id, row[2]) for row in unread_counts], - value_names=("notif_count", "unread_count"), - value_values=[(row[0], row[1]) for row in unread_counts], + value_names=( + "notif_count", + "unread_count", + "stream_ordering", + "last_receipt_stream_ordering", + ), + value_values=[ + (row[0], row[1], old_rotate_stream_ordering, stream_ordering) + for row in unread_counts + ], ) # We always update `event_push_summary_last_receipt_stream_id` to diff --git a/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres new file mode 100644 index 0000000000..3e0bc9e5eb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Drop constraint on (room_id, receipt_type, user_id). + +-- Rebuild the unique constraint with the thread_id. +ALTER TABLE receipts_linearized + DROP CONSTRAINT receipts_linearized_uniqueness; + +ALTER TABLE receipts_graph + DROP CONSTRAINT receipts_graph_uniqueness; diff --git a/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite new file mode 100644 index 0000000000..e664889fbc --- /dev/null +++ b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite @@ -0,0 +1,76 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Drop constraint on (room_id, receipt_type, user_id). +-- +-- SQLite doesn't support modifying constraints to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE receipts_linearized_new ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + thread_id TEXT, + event_stream_ordering BIGINT, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +CREATE TABLE receipts_graph_new ( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + thread_id TEXT, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id) +); + +-- Drop the old indexes. +DROP INDEX IF EXISTS receipts_linearized_id; +DROP INDEX IF EXISTS receipts_linearized_room_stream; +DROP INDEX IF EXISTS receipts_linearized_user; + +-- Copy the data. +INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, data) + SELECT stream_id, room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized; +INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data) + SELECT room_id, receipt_type, user_id, event_ids, data + FROM receipts_graph; + +-- Drop the old tables. +DROP TABLE receipts_linearized; +DROP TABLE receipts_graph; + +-- Rename the tables. +ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized; +ALTER TABLE receipts_graph_new RENAME TO receipts_graph; + +-- Create the indices. +CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); +CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); +CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); + +-- Re-run background updates from 72/08thread_receipts.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7308, 'receipts_linearized_unique_index', '{}') + ON CONFLICT (update_name) DO NOTHING; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7308, 'receipts_graph_unique_index', '{}') + ON CONFLICT (update_name) DO NOTHING; diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 89f986ac34..6fa0cafb75 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -16,6 +16,7 @@ from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import MAIN_TIMELINE from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -312,7 +313,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): def _rotate() -> None: self.get_success(self.store._rotate_notifs()) - def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + def _mark_read(event_id: str, thread_id: str = MAIN_TIMELINE) -> None: self.get_success( self.store.insert_receipt( room_id, @@ -348,9 +349,12 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _create_event() _create_event(thread_id=thread_id) _mark_read(event_id) + _assert_counts(1, 0, 3, 0) + _mark_read(event_id, thread_id) _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event() @@ -364,6 +368,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event(True) @@ -389,8 +394,190 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): # Check that sending read receipts at different points results in the # right counts. _mark_read(event_id) + _assert_counts(1, 0, 2, 1) + _mark_read(event_id, thread_id) _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _assert_counts(0, 0, 1, 0) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + + def test_count_aggregation_mixed(self) -> None: + """ + This is essentially the same test as test_count_aggregation_threads, but + sends both unthreaded and threaded receipts. + """ + + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id, MAIN_TIMELINE) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(event_id, MAIN_TIMELINE) + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id, MAIN_TIMELINE) + _assert_counts(0, 0, 1, 0) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event(True) -- cgit 1.5.1 From 2b6d41ebd685fb546e52acdbcb0024dfcf5a5db1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 11:36:16 -0400 Subject: Recursively fetch the thread for receipts & notifications. (#13824) Consider an event to be part of a thread if you can follow a chain of relations up to a thread root. Part of MSC3773 & MSC3771. --- changelog.d/13824.feature | 1 + synapse/push/bulk_push_rule_evaluator.py | 5 ++ synapse/rest/client/receipts.py | 22 +++++- synapse/storage/databases/main/relations.py | 36 ++++++++++ tests/storage/test_event_push_actions.py | 100 ++++++++++++++++++++++++++++ 5 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13824.feature (limited to 'tests/storage') diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13824.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 61d952742d..f8c4dd74f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -286,8 +286,13 @@ class BulkPushRuleEvaluator: relation.parent_id, itertools.chain(*(r.rules() for r in rules_by_user.values())), ) + # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id + else: + # Since the event has not yet been persisted we check whether + # the parent is part of a thread. + thread_id = await self.store.get_thread_id(relation.parent_id) or "main" evaluator = PushRuleEvaluator( _flatten_dict(event), diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index f3ff156abe..287dfdd69e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._main_store = hs.get_datastores().main self._known_receipt_types = { ReceiptTypes.READ, @@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet): thread_id = body.get("thread_id") if not thread_id or not isinstance(thread_id, str): raise SynapseError( - 400, "thread_id field must be a non-empty string" + 400, + "thread_id field must be a non-empty string", + Codes.INVALID_PARAM, + ) + + if receipt_type == ReceiptTypes.FULLY_READ: + raise SynapseError( + 400, + f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", + Codes.INVALID_PARAM, + ) + + # Ensure the event ID roughly correlates to the thread ID. + if thread_id != await self._main_store.get_thread_id(event_id): + raise SynapseError( + 400, + f"event_id {event_id} is not related to thread {thread_id}", + Codes.INVALID_PARAM, ) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 898947af95..154385b1e8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_event_relations", _get_event_relations ) + @cached() + async def get_thread_id(self, event_id: str) -> Optional[str]: + """ + Get the thread ID for an event. This considers multi-level relations, + e.g. an annotation to an event which is part of a thread. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. None, otherwise. + """ + # Since event relations form a tree, we should only ever find 0 or 1 + # results from the below query. + sql = """ + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + """ + + def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + txn.execute(sql, (event_id,)) + # TODO Should we ensure there's only a single result here? + row = txn.fetchone() + if row: + return row[0] + return None + + return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 6fa0cafb75..886585e9f2 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0, 0, 0) + def test_recursive_thread(self) -> None: + """ + Events related to events in a thread should still be considered part of + that thread. + """ + + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + + # Update the user's push rules to care about reaction events. + self.get_success( + self.store.add_push_rule( + user_id, + "related_events", + priority_class=5, + conditions=[ + {"kind": "event_match", "key": "type", "pattern": "m.reaction"} + ], + actions=["notify"], + ) + ) + + def _create_event(type: str, content: JsonDict) -> str: + result = self.helper.send_event( + room_id, type=type, content=content, tok=other_token + ) + return result["event_id"] + + def _assert_counts(noitf_count: int, thread_notif_count: int) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, unread_count=0, highlight_count=0 + ), + ) + if thread_notif_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=0, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + # Create a root event. + thread_id = _create_event( + "m.room.message", {"msgtype": "m.text", "body": "msg"} + ) + _assert_counts(1, 0) + + # Reply, creating a thread. + reply_id = _create_event( + "m.room.message", + { + "msgtype": "m.text", + "body": "msg", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": thread_id, + }, + }, + ) + _assert_counts(1, 1) + + # Create an event related to a thread event, this should still appear in + # the thread. + _create_event( + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": "m.annotation", + "event_id": reply_id, + "key": "A", + } + }, + ) + _assert_counts(1, 2) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From dcced5a8d76b94e372aefa7d1f05ec0dbc22ea0d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 12:07:02 -0400 Subject: Use threaded receipts when fetching events for push. (#13878) Update the HTTP and email pushers to consider threaded read receipts when fetching unread events. --- changelog.d/13878.feature | 1 + .../storage/databases/main/event_push_actions.py | 80 +++++++++++++++------- tests/storage/test_event_push_actions.py | 57 ++++++++++----- 3 files changed, 97 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13878.feature (limited to 'tests/storage') diff --git a/changelog.d/13878.feature b/changelog.d/13878.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13878.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7469cd336c..332e13d1c9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [ ] +@attr.s(slots=True, auto_attribs=True) +class _RoomReceipt: + """ + HttpPushAction instances include the information used to generate HTTP + requests to a push gateway. + """ + + unthreaded_stream_ordering: int = 0 + # threaded_stream_ordering includes the main pseudo-thread. + threaded_stream_ordering: Dict[str, int] = attr.Factory(dict) + + def is_unread(self, thread_id: str, stream_ordering: int) -> bool: + """Returns True if the stream ordering is unread according to the receipt information.""" + + # Only include push actions with a stream ordering after both the unthreaded + # and threaded receipt. Properly handles a user without any receipts present. + return ( + self.unthreaded_stream_ordering < stream_ordering + and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering + ) + + +# A _RoomReceipt with no receipts in it. +MISSING_ROOM_RECEIPT = _RoomReceipt() + + @attr.s(slots=True, frozen=True, auto_attribs=True) class HttpPushAction: """ @@ -716,7 +742,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> Dict[str, int]: + ) -> Dict[str, _RoomReceipt]: """ Generate a map of room ID to the latest stream ordering that has been read by the given user. @@ -726,7 +752,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to fetch receipts for. Returns: - A map of room ID to stream ordering for all rooms the user has a receipt in. + A map including all rooms the user is in with a receipt. It maps + room IDs to _RoomReceipt instances """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, @@ -735,20 +762,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) sql = f""" - SELECT room_id, MAX(stream_ordering) + SELECT room_id, thread_id, MAX(stream_ordering) FROM receipts_linearized INNER JOIN events USING (room_id, event_id) WHERE {receipt_types_clause} AND user_id = ? - GROUP BY room_id + GROUP BY room_id, thread_id """ args.extend((user_id,)) txn.execute(sql, args) - return { - room_id: latest_stream_ordering - for room_id, latest_stream_ordering in txn.fetchall() - } + + result: Dict[str, _RoomReceipt] = {} + for room_id, thread_id, stream_ordering in txn: + room_receipt = result.setdefault(room_id, _RoomReceipt()) + if thread_id is None: + room_receipt.unthreaded_stream_ordering = stream_ordering + else: + room_receipt.threaded_stream_ordering[thread_id] = stream_ordering + + return result async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -781,9 +814,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_push_actions_txn( txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool]]: + ) -> List[Tuple[str, str, str, int, str, bool]]: sql = """ - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight + SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, + ep.actions, ep.highlight FROM event_push_actions AS ep WHERE ep.user_id = ? @@ -793,7 +827,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ORDER BY ep.stream_ordering ASC LIMIT ? """ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) + return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall()) push_actions = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn @@ -806,10 +840,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas stream_ordering=stream_ordering, actions=_deserialize_action(actions, highlight), ) - for event_id, room_id, stream_ordering, actions, highlight in push_actions - # Only include push actions with a stream ordering after any receipt, or without any - # receipt present (invited to but never read rooms). - if stream_ordering > receipts_by_room.get(room_id, 0) + for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions + if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( + thread_id, stream_ordering + ) ] # Now sort it so it's ordered correctly, since currently it will @@ -853,10 +887,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def get_push_actions_txn( txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool, int]]: + ) -> List[Tuple[str, str, str, int, str, bool, int]]: sql = """ - SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, - ep.highlight, e.received_ts + SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, + ep.actions, ep.highlight, e.received_ts FROM event_push_actions AS ep INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -867,7 +901,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ORDER BY ep.stream_ordering DESC LIMIT ? """ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) + return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall()) push_actions = await self.db_pool.runInteraction( "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn @@ -882,10 +916,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas actions=_deserialize_action(actions, highlight), received_ts=received_ts, ) - for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions - # Only include push actions with a stream ordering after any receipt, or without any - # receipt present (invited to but never read rooms). - if stream_ordering > receipts_by_room.get(room_id, 0) + for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions + if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( + thread_id, stream_ordering + ) ] # Now sort it so it's ordered correctly, since currently it will diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 886585e9f2..ee48920f84 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -16,7 +16,7 @@ from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import MAIN_TIMELINE +from synapse.api.constants import MAIN_TIMELINE, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -66,16 +66,23 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): user_id, token, _, other_token, room_id = self._create_users_and_room() # Create two events, one of which is a highlight. - self.helper.send_event( + first_event_id = self.helper.send_event( room_id, type="m.room.message", content={"msgtype": "m.text", "body": "msg"}, tok=other_token, - ) - event_id = self.helper.send_event( + )["event_id"] + second_event_id = self.helper.send_event( room_id, type="m.room.message", - content={"msgtype": "m.text", "body": user_id}, + content={ + "msgtype": "m.text", + "body": user_id, + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": first_event_id, + }, + }, tok=other_token, )["event_id"] @@ -95,13 +102,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) self.assertEqual(2, len(email_actions)) - # Send a receipt, which should clear any actions. + # Send a receipt, which should clear the first action. self.get_success( self.store.insert_receipt( room_id, "m.read", user_id=user_id, - event_ids=[event_id], + event_ids=[first_event_id], thread_id=None, data={}, ) @@ -111,6 +118,30 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): user_id, 0, 1000, 20 ) ) + self.assertEqual(1, len(http_actions)) + email_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_email( + user_id, 0, 1000, 20 + ) + ) + self.assertEqual(1, len(email_actions)) + + # Send a thread receipt to clear the thread action. + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[second_event_id], + thread_id=first_event_id, + data={}, + ) + ) + http_actions = self.get_success( + self.store.get_unread_push_actions_for_user_in_range_for_http( + user_id, 0, 1000, 20 + ) + ) self.assertEqual([], http_actions) email_actions = self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_email( @@ -417,17 +448,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): sends both unthreaded and threaded receipts. """ - # Create a user to receive notifications and send receipts. - user_id = self.register_user("user1235", "pass") - token = self.login("user1235", "pass") - - # And another users to send events. - other_id = self.register_user("other", "pass") - other_token = self.login("other", "pass") - - # Create a room and put both users in it. - room_id = self.helper.create_room_as(user_id, tok=token) - self.helper.join(room_id, other_id, tok=other_token) + user_id, token, _, other_token, room_id = self._create_users_and_room() thread_id: str last_event_id: str -- cgit 1.5.1 From d1bdeccb50550ef454067aa01dd9d004c4704633 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 14:05:25 -0400 Subject: Accept threaded receipts for events related to the root event. (#14174) The root node of a thread (and events related to it) are considered "part of a thread" when validating receipts. This allows clients which show the root node in both the main timeline and the threaded timeline to easily send receipts in either. Note that threaded notifications are not created for these events, these events created notifications on the main timeline. --- changelog.d/14174.feature | 1 + synapse/rest/client/receipts.py | 44 ++++++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/relations.py | 98 ++++++++++++++++++++++-- tests/storage/test_relations.py | 111 ++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14174.feature create mode 100644 tests/storage/test_relations.py (limited to 'tests/storage') diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14174.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 14dec7ac4e..18a282b22c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet): ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet): return 200, {} + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: + """ + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. + + A thread ID is considered valid for a given event E if: + + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which + satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main timeline on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread, otherwise False. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index b47fc606c7..ed0be4abe5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7c54ce0b2e..1de62ee9df 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore): Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. @@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore): The event ID of the root event in the thread, if this event is part of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] @@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py new file mode 100644 index 0000000000..cd1d00208b --- /dev/null +++ b/tests/storage/test_relations.py @@ -0,0 +1,111 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import MAIN_TIMELINE +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class RelationsStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + """ + Creates a DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + F <--[m.annotation]-- G + + """ + self._main_store = self.hs.get_datastores().main + + self._create_relation("A", "B", "m.thread") + self._create_relation("B", "C", "m.annotation") + self._create_relation("A", "D", "m.reference") + self._create_relation("D", "E", "m.annotation") + self._create_relation("F", "G", "m.annotation") + + def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None: + self.get_success( + self._main_store.db_pool.simple_insert( + table="event_relations", + values={ + "event_id": event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + }, + ) + ) + + def test_get_thread_id(self) -> None: + """ + Ensure that get_thread_id only searches up the tree for threads. + """ + # The thread itself and children of it return the thread. + thread_id = self.get_success(self._main_store.get_thread_id("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("C")) + self.assertEqual("A", thread_id) + + # But the root and events related to the root do not. + thread_id = self.get_success(self._main_store.get_thread_id("A")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("D")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("E")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + def test_get_thread_id_for_receipts(self) -> None: + """ + Ensure that get_thread_id_for_receipts searches up and down the tree for a thread. + """ + # All of the events are considered related to this thread. + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E")) + self.assertEqual("A", thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) -- cgit 1.5.1 From 40bb37eb27e1841754a297ac1277748de7f6c1cb Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Sat, 15 Oct 2022 00:36:49 -0500 Subject: Stop getting missing `prev_events` after we already know their signature is invalid (#13816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While https://github.com/matrix-org/synapse/pull/13635 stops us from doing the slow thing after we've already done it once, this PR stops us from doing one of the slow things in the first place. Related to - https://github.com/matrix-org/synapse/issues/13622 - https://github.com/matrix-org/synapse/pull/13635 - https://github.com/matrix-org/synapse/issues/13676 Part of https://github.com/matrix-org/synapse/issues/13356 Follow-up to https://github.com/matrix-org/synapse/pull/13815 which tracks event signature failures. With this PR, we avoid the call to the costly `_get_state_ids_after_missing_prev_event` because the signature failure will count as an attempt before and we filter events based on the backoff before calling `_get_state_ids_after_missing_prev_event` now. For example, this will save us 156s out of the 185s total that this `matrix.org` `/messages` request. If you want to see the full Jaeger trace of this, you can drag and drop this `trace.json` into your own Jaeger, https://gist.github.com/MadLittleMods/4b12d0d0afe88c2f65ffcc907306b761 To explain this exact scenario around `/messages` -> backfill, we call `/backfill` and first check the signatures of the 100 events. We see bad signature for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` and `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` (both member events). Then we process the 98 events remaining that have valid signatures but one of the events references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event`. So we have to do the whole `_get_state_ids_after_missing_prev_event` rigmarole which pulls in those same events which fail again because the signatures are still invalid. - `backfill` - `outgoing-federation-request` `/backfill` - `_check_sigs_and_hash_and_fetch` - `_check_sigs_and_hash_and_fetch_one` for each event received over backfill - ❗ `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - ❗ `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` fails with `Signature on retrieved event was invalid.`: `unable to verify signature for sender domain xxx: 401: Failed to find any key to satisfy: _FetchKeyRequest(...)` - `_process_pulled_events` - `_process_pulled_event` for each validated event - ❗ Event `$Q0iMdqtz3IJYfZQU2Xk2WjB5NDF8Gg8cFSYYyKQgKJ0` references `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` as a `prev_event` which is missing so we try to get it - `_get_state_ids_after_missing_prev_event` - `outgoing-federation-request` `/state_ids` - ❗ `get_pdu` for `$luA4l7QHhf_jadH3mI-AyFqho0U2Q-IXXUbGSMq6h6M` which fails the signature check again - ❗ `get_pdu` for `$zuOn2Rd2vsC7SUia3Hp3r6JSkSFKcc5j3QTTqW_0jDw` which fails the signature check --- changelog.d/13816.feature | 1 + synapse/api/errors.py | 21 +++ synapse/handlers/federation.py | 16 ++ synapse/handlers/federation_event.py | 31 ++++ synapse/storage/databases/main/event_federation.py | 54 ++++++ tests/handlers/test_federation_event.py | 201 ++++++++++++++++++++- tests/storage/test_event_federation.py | 64 +++++++ 7 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 changelog.d/13816.feature (limited to 'tests/storage') diff --git a/changelog.d/13816.feature b/changelog.d/13816.feature new file mode 100644 index 0000000000..5eaa936b08 --- /dev/null +++ b/changelog.d/13816.feature @@ -0,0 +1 @@ +Stop fetching missing `prev_events` after we already know their signature is invalid. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c606207569..e0873b1913 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -640,6 +640,27 @@ class FederationError(RuntimeError): } +class FederationPullAttemptBackoffError(RuntimeError): + """ + Raised to indicate that we are are deliberately not attempting to pull the given + event over federation because we've already done so recently and are backing off. + + Attributes: + event_id: The event_id which we are refusing to pull + message: A custom error message that gives more context + """ + + def __init__(self, event_ids: List[str], message: Optional[str]): + self.event_ids = event_ids + + if message: + error_message = message + else: + error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)." + + super().__init__(error_message) + + class HttpResponseException(CodeMessageException): """ Represents an HTTP-level failure of an outbound request diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 44e70c6c3c..5f7e0a1f79 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, LimitExceededError, NotFoundError, @@ -1720,7 +1721,22 @@ class FederationHandler: destination, event ) break + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_sync_partial_state_room: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 + if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f382961099..4300e8dd40 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -44,6 +44,7 @@ from synapse.api.errors import ( AuthError, Codes, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, RequestSendFailed, SynapseError, @@ -567,6 +568,9 @@ class FederationEventHandler: event: partial-state event to be de-partial-stated Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) @@ -901,6 +905,18 @@ class FederationEventHandler: context, backfilled=backfilled, ) + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_process_pulled_event: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: await self._store.record_event_failed_pull_attempt( event.room_id, event_id, str(e) @@ -947,6 +963,9 @@ class FederationEventHandler: The event context. Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -957,6 +976,18 @@ class FederationEventHandler: seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen + # If we've already recently attempted to pull this missing event, don't + # try it again so soon. Since we have to fetch all of the prev_events, we can + # bail early here if we find any to ignore. + prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) + if len(prevs_to_ignore) > 0: + raise FederationPullAttemptBackoffError( + event_ids=prevs_to_ignore, + message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + ) + if not missing_prevs: return await self._state_handler.compute_event_context(event) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6b9a629edd..309a4ba664 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 918010cddb..e448cb1901 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -14,7 +14,7 @@ from typing import Optional from unittest import mock -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion from synapse.event_auth import ( check_state_dependent_auth_rules, @@ -43,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): def make_homeserver(self, reactor, clock): # mock out the federation transport client self.mock_federation_transport_client = mock.Mock( - spec=["get_room_state_ids", "get_room_state", "get_event"] + spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client @@ -459,6 +459,203 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) self.assertIsNotNone(persisted, "pulled event was not persisted at all") + def test_backfill_signature_failure_does_not_fetch_same_prev_event_later( + self, + ) -> None: + """ + Test to make sure we backoff and don't try to fetch a missing prev_event when we + already know it has a invalid signature from checking the signatures of all of + the events in the backfill response. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # Add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + member_event.event_id, + ] + + # We purposely don't run `add_hashes_and_signatures_from_other_server` + # over this because we want the signature check to fail. + pulled_event_without_signatures = make_event_from_dict( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [member_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event_without_signatures"}, + }, + room_version, + ) + + # Create a regular event that should pass except for the + # `pulled_event_without_signatures` in the `prev_event`. + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + member_event.event_id, + pulled_event_without_signatures.event_id, + ], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event"}, + } + ), + room_version, + ) + + # We expect an outbound request to /backfill, so stub that out + self.mock_federation_transport_client.backfill.return_value = make_awaitable( + { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter down the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # backoff and don't try and fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. + pulled_event.get_pdu_json(), + ], + } + ) + + # Keep track of the count and make sure we don't make any of these requests + event_endpoint_requested_count = 0 + room_state_ids_endpoint_requested_count = 0 + room_state_endpoint_requested_count = 0 + + async def get_event( + destination: str, event_id: str, timeout: Optional[int] = None + ) -> None: + nonlocal event_endpoint_requested_count + event_endpoint_requested_count += 1 + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_ids_endpoint_requested_count + room_state_ids_endpoint_requested_count += 1 + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> None: + nonlocal room_state_endpoint_requested_count + room_state_endpoint_requested_count += 1 + + # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in + # the happy path but if the logic is sneaking around what we expect, stub that + # out so we can detect that failure + self.mock_federation_transport_client.get_event.side_effect = get_event + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + # The function under test: try to backfill and process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler().backfill( + self.OTHER_SERVER_NAME, + room_id, + limit=1, + extremities=["$some_extremity"], + ) + ) + + if event_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /event in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_ids_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state_ids in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + if room_state_endpoint_requested_count > 0: + self.fail( + "We don't expect an outbound request to /state in the happy path but if " + "the logic is sneaking around what we expect, make sure to fail the test. " + "We don't expect it because the signature failure should cause us to backoff " + "and not asking about pulled_event_without_signatures=" + f"{pulled_event_without_signatures.event_id} again" + ) + + # Make sure we only recorded a single failure which corresponds to the signature + # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before + # we process all of the pulled events. + backfill_num_attempts_for_event_without_signatures = self.get_success( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event_without_signatures.event_id}, + retcol="num_attempts", + ) + ) + self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1) + + # And make sure we didn't record a failure for the event that has the missing + # prev_event because we don't want to cause a cascade of failures. Not being + # able to fetch the `prev_events` just means we won't be able to de-outlier the + # pulled event. But we can still use an `outlier` in the state/auth chain for + # another event. So we shouldn't stop a downstream event from trying to pull it. + self.get_failure( + main_store.db_pool.simple_select_one_onecol( + table="event_failed_pull_attempts", + keyvalues={"event_id": pulled_event.event_id}, + retcol="num_attempts", + ), + # StoreError: 404: No row found + StoreError, + ) + def test_process_pulled_event_with_rejected_missing_state(self) -> None: """Ensure that we correctly handle pulled events with missing state containing a rejected state event diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 59b8910907..853db930d6 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,6 +27,8 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.events import _EventInternalMetadata +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -43,6 +45,12 @@ class _BackfillSetupInfo: class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -1122,6 +1130,62 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["insertion_eventA"]) + def test_get_event_ids_to_not_pull_from_backoff( + self, + ): + """ + Test to make sure only event IDs we should backoff from are returned. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + + self.assertEqual(event_ids_to_backoff, ["$failed_event_id"]) + + def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( + self, + ): + """ + Test to make sure no event IDs are returned after the backoff duration has + elapsed. + """ + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + self.get_success( + self.store.record_event_failed_pull_attempt( + room_id, "$failed_event_id", "fake cause" + ) + ) + + # Now advance time by 2 hours so we wait long enough for the single failed + # attempt (2^1 hours). + self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) + + event_ids_to_backoff = self.get_success( + self.store.get_event_ids_to_not_pull_from_backoff( + room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] + ) + ) + # Since this function only returns events we should backoff from, time has + # elapsed past the backoff range so there is no events to backoff from. + self.assertEqual(event_ids_to_backoff, []) + @attr.s class FakeEvent: -- cgit 1.5.1 From 4283bd1cf9c3da2157c3642a7c4f105e9fac2636 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Oct 2022 11:32:11 -0400 Subject: Support filtering the /messages API by relation type (MSC3874). (#14148) Gated behind an experimental configuration flag. --- changelog.d/14148.feature | 1 + synapse/api/filtering.py | 27 +++++- synapse/config/experimental.py | 3 + synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/stream.py | 29 ++++++- tests/api/test_filtering.py | 63 +++++++++++++- tests/rest/client/test_relations.py | 1 - tests/rest/client/test_rooms.py | 145 ++----------------------------- tests/storage/test_stream.py | 118 ++++++++++++++++++------- 9 files changed, 212 insertions(+), 177 deletions(-) create mode 100644 changelog.d/14148.feature (limited to 'tests/storage') diff --git a/changelog.d/14148.feature b/changelog.d/14148.feature new file mode 100644 index 0000000000..951d0cac80 --- /dev/null +++ b/changelog.d/14148.feature @@ -0,0 +1 @@ +Experimental support for [MSC3874](https://github.com/matrix-org/matrix-spec-proposals/pull/3874). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -117,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4e1fd2bbe7..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -114,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5baffbfe55..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1278,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1294,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a269c477fb..a82c4eed86 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_rel_type(self): + definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_not_rel_type(self): + definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = self.get_success( @@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): events = [ # An event without a relation. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f5c1070b2c..ddf315b894 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ Test that thread replies are still available when the root event is redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3612ebe7b9..71b1637be8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -35,7 +35,6 @@ from synapse.api.constants import ( EventTypes, Membership, PublicRoomsFilterFields, - RelationTypes, RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException @@ -50,6 +49,7 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -2915,149 +2915,20 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - self.helper.join( - room=self.room_id, user=self.second_user_id, tok=self.second_tok - ) - - self.third_user_id = self.register_user("third", "test") - self.third_tok = self.login("third", "test") - self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) - - # An initial event with a relation from second user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 1"}, - tok=self.tok, - ) - self.event_id_1 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.ANNOTATION, - "event_id": self.event_id_1, - "key": "👍", - } - }, - tok=self.second_tok, - ) - - # Another event with a relation from third user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 2"}, - tok=self.tok, - ) - self.event_id_2 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.REFERENCE, - "event_id": self.event_id_2, - } - }, - tok=self.third_tok, - ) - - # An event with no relations. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "No relations"}, - tok=self.tok, - ) - - def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: +class RelationsTestCase(PaginationTestCase): + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" + from_token = self.get_success( + self.from_token.to_string(self.hs.get_datastores().main) + ) channel = self.make_request( "GET", - "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}", access_token=self.tok, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - return channel.json_body["chunk"] - - def test_filter_relation_senders(self) -> None: - # Messages which second user reacted to. - filter = {"related_by_senders": [self.second_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which third user reacted to. - filter = {"related_by_senders": [self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which either user reacted to. - filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_type(self) -> None: - # Messages which have annotations. - filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which have references. - filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which have either annotations or references. - filter = { - "related_by_rel_types": [ - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - ] - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_senders_and_type(self) -> None: - # Messages which second user reacted to. - filter = { - "related_by_senders": [self.second_user_id], - "related_by_rel_types": [RelationTypes.ANNOTATION], - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) + return [ev["event_id"] for ev in channel.json_body["chunk"]] class ContextTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 78663a53fe..34fa810cf6 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -16,7 +16,6 @@ from typing import List from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.types import JsonDict @@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} + config["experimental_features"] = {"msc3874_enabled": True} return config def prepare(self, reactor, clock, homeserver): @@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase): self.third_tok = self.login("third", "test") self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + # Store a token which is after all the room creation events. + self.from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) + ) + # An initial event with a relation from second user. res = self.helper.send_event( room_id=self.room_id, @@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_1 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.second_tok, ) + self.event_id_annotation = res["event_id"] # Another event with a relation from third user. res = self.helper.send_event( @@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_2 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.third_tok, ) + self.event_id_reference = res["event_id"] # An event with no relations. - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "No relations"}, tok=self.tok, ) + self.event_id_none = res["event_id"] - def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) - ) - events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, - from_key=from_token.room_key, + from_key=self.from_token.room_key, to_key=None, - direction="b", + direction="f", limit=10, event_filter=Filter(self.hs, filter), ) ) - return events + return [ev.event_id for ev in events] def test_filter_relation_senders(self): # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which third user reacted to. filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which either user reacted to. filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_type(self): # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which have references. filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which have either annotations or references. filter = { @@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase): ] } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. @@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase): "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) def test_duplicate_relation(self): """An event should only be returned once if there are multiple relations to it.""" @@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase): filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) + + def test_filter_rel_types(self) -> None: + # Messages which are annotations. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_annotation]) + + # Messages which are references. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_reference]) + + # Messages which are either annotations or references. + filter = { + "org.matrix.msc3874.rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertCountEqual( + chunk, + [self.event_id_annotation, self.event_id_reference], + ) + + def test_filter_not_rel_types(self) -> None: + # Messages which are not annotations. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_2, + self.event_id_reference, + self.event_id_none, + ], + ) + + # Messages which are not references. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_annotation, + self.event_id_2, + self.event_id_none, + ], + ) + + # Messages which are neither annotations or references. + filter = { + "org.matrix.msc3874.not_rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) -- cgit 1.5.1 From 828b5502cfdf4f1b20750941714ce95cdb242f0d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 18 Oct 2022 10:33:21 +0100 Subject: Remove `_get_events_cache` check optimisation from `_have_seen_events_dict` (#14161) --- changelog.d/14161.bugfix | 1 + synapse/storage/databases/main/events_worker.py | 31 +++++++++------------- tests/storage/databases/main/test_events_worker.py | 12 --------- 3 files changed, 14 insertions(+), 30 deletions(-) create mode 100644 changelog.d/14161.bugfix (limited to 'tests/storage') diff --git a/changelog.d/14161.bugfix b/changelog.d/14161.bugfix new file mode 100644 index 0000000000..aed4d9e386 --- /dev/null +++ b/changelog.d/14161.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.30.0 where purging and rejoining a room without restarting in-between would result in a broken room. \ No newline at end of file diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d4104462b5..cfd4780add 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1502,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1524,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 32a798d74b..5773172ab8 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -90,18 +90,6 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_query_via_event_cache(self): - # fetch an event into the event cache - self.get_success(self.store.get_event(self.event_ids[0])) - - # looking it up should now cause no db hits - with LoggingContext(name="test") as ctx: - res = self.get_success( - self.store.have_seen_events(self.room_id, [self.event_ids[0]]) - ) - self.assertEqual(res, {self.event_ids[0]}) - self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) - def test_persisting_event_invalidates_cache(self): """ Test to make sure that the `have_seen_event` cache -- cgit 1.5.1 From d902181de98399d90c46c4e4e2cf631064757941 Mon Sep 17 00:00:00 2001 From: James Salter Date: Tue, 25 Oct 2022 19:05:22 +0100 Subject: Unified search query syntax using the full-text search capabilities of the underlying DB. (#11635) Support a unified search query syntax which leverages more of the full-text search of each database supported by Synapse. Supports, with the same syntax across Postgresql 11+ and Sqlite: - quoted "search terms" - `AND`, `OR`, `-` (negation) operators - Matching words based on their stem, e.g. searches for "dog" matches documents containing "dogs". This is achieved by - If on postgresql 11+, pass the user input to `websearch_to_tsquery` - If on sqlite, manually parse the query and transform it into the sqlite-specific query syntax. Note that postgresql 10, which is close to end-of-life, falls back to using `phraseto_tsquery`, which only supports a subset of the features. Multiple terms separated by a space are implicitly ANDed. Note that: 1. There is no escaping of full-text syntax that might be supported by the database; e.g. `NOT`, `NEAR`, `*` in sqlite. This runs the risk that people might discover this as accidental functionality and depend on something we don't guarantee. 2. English text is assumed for stemming. To support other languages, either the target language needs to be known at the time of indexing the message (via room metadata, or otherwise), or a separate index for each language supported could be created. Sqlite docs: https://www.sqlite.org/fts3.html#full_text_index_queries Postgres docs: https://www.postgresql.org/docs/11/textsearch-controls.html --- changelog.d/11635.feature | 1 + synapse/storage/databases/main/search.py | 197 +++++++++++++++---- synapse/storage/engines/postgres.py | 16 ++ .../delta/73/10_update_sqlite_fts4_tokenizer.py | 62 ++++++ tests/storage/test_room_search.py | 213 +++++++++++++++++++++ 5 files changed, 454 insertions(+), 35 deletions(-) create mode 100644 changelog.d/11635.feature create mode 100644 synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py (limited to 'tests/storage') diff --git a/changelog.d/11635.feature b/changelog.d/11635.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/11635.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 1b79acf955..a89fc54c2c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -11,10 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple +from collections import deque +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr @@ -27,7 +39,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -421,8 +433,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -444,20 +454,24 @@ class SearchStore(SearchBackgroundUpdateStore): count_clauses = clauses if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank," " room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?)" + f" WHERE vector @@ {tsquery_func}('english', ?)" ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_for_sqlite(search_term) + sql = ( "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" " FROM event_search" @@ -469,7 +483,7 @@ class SearchStore(SearchBackgroundUpdateStore): "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ?" ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -501,7 +515,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -510,7 +526,6 @@ class SearchStore(SearchBackgroundUpdateStore): ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) - return { "results": [ {"event": event_map[r["event_id"]], "rank": r["rank"]} @@ -542,9 +557,6 @@ class SearchStore(SearchBackgroundUpdateStore): Each match as a dictionary. """ clauses = [] - - search_query = _parse_query(self.database_engine, search_term) - args: List[Any] = [] # Make sure we don't explode because the person is in too many rooms. @@ -582,20 +594,23 @@ class SearchStore(SearchBackgroundUpdateStore): args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): + search_query = search_term + tsquery_func = self.database_engine.tsquery_func sql = ( - "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank," + f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank," " origin_server_ts, stream_ordering, room_id, event_id" " FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) args = [search_query, search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" - " WHERE vector @@ to_tsquery('english', ?) AND " + f" WHERE vector @@ {tsquery_func}('english', ?) AND " ) count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): + # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # @@ -614,13 +629,14 @@ class SearchStore(SearchBackgroundUpdateStore): " CROSS JOIN events USING (event_id)" " WHERE " ) + search_query = _parse_query_for_sqlite(search_term) args = [search_query] + args count_sql = ( "SELECT room_id, count(*) as count FROM event_search" " WHERE value MATCH ? AND " ) - count_args = [search_term] + count_args + count_args = [search_query] + count_args else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -660,7 +676,9 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres(search_query, events) + highlights = await self._find_highlights_in_postgres( + search_query, events, tsquery_func + ) count_sql += " GROUP BY room_id" @@ -686,7 +704,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase] + self, search_query: str, events: List[EventBase], tsquery_func: str ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -697,6 +715,7 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events + tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -729,7 +748,7 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % ( + query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( _to_postgres_options( { "StartSel": start_sel, @@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: - """Takes a plain unicode string from the user and converts it into a form - that can be passed to database. - We use this so that we can add prefix matching, which isn't something - that is supported by default. +@dataclass +class Phrase: + phrase: List[str] + + +class SearchToken(enum.Enum): + Not = enum.auto() + Or = enum.auto() + And = enum.auto() + + +Token = Union[str, Phrase, SearchToken] +TokenList = List[Token] + + +def _is_stop_word(word: str) -> bool: + # TODO Pull these out of the dictionary: + # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop + return word in {"the", "a", "you", "me", "and", "but"} + + +def _tokenize_query(query: str) -> TokenList: + """ + Convert the user-supplied `query` into a TokenList, which can be translated into + some DB-specific syntax. + + The following constructs are supported: + + - phrase queries using "double quotes" + - case-insensitive `or` and `and` operators + - negation of a keyword via unary `-` + - unary hyphen to denote NOT e.g. 'include -exclude' + + The following differs from websearch_to_tsquery: + + - Stop words are not removed. + - Unclosed phrases are treated differently. + + """ + tokens: TokenList = [] + + # Find phrases. + in_phrase = False + parts = deque(query.split('"')) + for i, part in enumerate(parts): + # The contents inside double quotes is treated as a phrase, a trailing + # double quote is not implied. + in_phrase = bool(i % 2) and i != (len(parts) - 1) + + # Pull out the individual words, discarding any non-word characters. + words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) + + # Phrases have simplified handling of words. + if in_phrase: + # Skip stop words. + phrase = [word for word in words if not _is_stop_word(word)] + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the phrase. + tokens.append(Phrase(phrase)) + continue + + # Otherwise, not in a phrase. + while words: + word = words.popleft() + + if word.startswith("-"): + tokens.append(SearchToken.Not) + + # If there's more word, put it back to be processed again. + word = word[1:] + if word: + words.appendleft(word) + elif word.lower() == "or": + tokens.append(SearchToken.Or) + else: + # Skip stop words. + if _is_stop_word(word): + continue + + # Consecutive words are implicitly ANDed together. + if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or): + tokens.append(SearchToken.And) + + # Add the search term. + tokens.append(word) + + return tokens + + +def _tokens_to_sqlite_match_query(tokens: TokenList) -> str: + """ + Convert the list of tokens to a string suitable for passing to sqlite's MATCH. + Assume sqlite was compiled with enhanced query syntax. + + Ref: https://www.sqlite.org/fts3.html#full_text_index_queries """ + match_query = [] + for token in tokens: + if isinstance(token, str): + match_query.append(token) + elif isinstance(token, Phrase): + match_query.append('"' + " ".join(token.phrase) + '"') + elif token == SearchToken.Not: + # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search + # term has already been added before this. + match_query.append(" NOT ") + elif token == SearchToken.Or: + match_query.append(" OR ") + elif token == SearchToken.And: + match_query.append(" AND ") + else: + raise ValueError(f"unknown token {token}") + + return "".join(match_query) - # Pull out the individual words, discarding any non-word characters. - results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) - if isinstance(database_engine, PostgresEngine): - return " & ".join(result + ":*" for result in results) - elif isinstance(database_engine, Sqlite3Engine): - return " & ".join(result + "*" for result in results) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") +def _parse_query_for_sqlite(search_term: str) -> str: + """Takes a plain unicode string from the user and converts it into a form + that can be passed to sqllite's matchinfo(). + """ + return _tokens_to_sqlite_match_query(_tokenize_query(search_term)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index d8c0f64d9a..9bf74bbf59 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,6 +170,22 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True + @property + def tsquery_func(self) -> str: + """ + Selects a tsquery_* func to use. + + Ref: https://www.postgresql.org/docs/current/textsearch-controls.html + + Returns: + The function name. + """ + # Postgres 11 added support for websearch_to_tsquery. + assert self._version is not None + if self._version >= 110000: + return "websearch_to_tsquery" + return "plainto_tsquery" + def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py new file mode 100644 index 0000000000..3de0a709eb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10_update_sqlite_fts4_tokenizer.py @@ -0,0 +1,62 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None: + """ + Upgrade the event_search table to use the porter tokenizer if it isn't already + + Applies only for sqlite. + """ + if not isinstance(database_engine, Sqlite3Engine): + return + + # Rebuild the table event_search table with tokenize=porter configured. + cur.execute("DROP TABLE event_search") + cur.execute( + """ + CREATE VIRTUAL TABLE event_search + USING fts4 (tokenize=porter, event_id, room_id, sender, key, value ) + """ + ) + + # Re-run the background job to re-populate the event_search table. + cur.execute("SELECT MIN(stream_ordering) FROM events") + row = cur.fetchone() + min_stream_id = row[0] + + # If there are not any events, nothing to do. + if min_stream_id is None: + return + + cur.execute("SELECT MAX(stream_ordering) FROM events") + row = cur.fetchone() + max_stream_id = row[0] + + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + } + progress_json = json.dumps(progress) + + sql = """ + INSERT into background_updates (ordering, update_name, progress_json) + VALUES (?, ?, ?) + """ + + cur.execute(sql, (7310, "event_search", progress_json)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index e747c6b50e..9ddc19900a 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple, Union +from unittest.case import SkipTest +from unittest.mock import PropertyMock, patch + +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.databases.main import DataStore +from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query from synapse.storage.engines import PostgresEngine +from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.util import Clock from tests.unittest import HomeserverTestCase, skip_unless from tests.utils import USE_POSTGRES_FOR_TESTS @@ -187,3 +198,205 @@ class EventSearchInsertionTest(HomeserverTestCase): ), ) self.assertCountEqual(values, ["hi", "2"]) + + +class MessageSearchTest(HomeserverTestCase): + """ + Check message search. + + A powerful way to check the behaviour is to run the following in Postgres >= 11: + + # SELECT websearch_to_tsquery('english', ); + + The result can be compared to the tokenized version for SQLite and Postgres < 11. + + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + PHRASE = "the quick brown fox jumps over the lazy dog" + + # Each entry is a search query, followed by either a boolean of whether it is + # in the phrase OR a tuple of booleans: whether it matches using websearch + # and using plain search. + COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + ("nope", False), + ("brown", True), + ("quick brown", True), + ("brown quick", True), + ("quick \t brown", True), + ("jump", True), + ("brown nope", False), + ('"brown quick"', (False, True)), + ('"jumps over"', True), + ('"quick fox"', (False, True)), + ("nope OR doublenope", False), + ("furphy OR fox", (True, False)), + ("fox -nope", (True, False)), + ("fox -brown", (False, True)), + ('"fox" quick', True), + ('"fox quick', True), + ('"quick brown', True), + ('" quick "', True), + ('" nope"', False), + ] + # TODO Test non-ASCII cases. + + # Case that fail on SQLite. + POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # SQLite treats NOT as a binary operator. + ("- fox", (False, True)), + ("- nope", (True, False)), + ('"-fox quick', (False, True)), + # PostgreSQL skips stop words. + ('"the quick brown"', True), + ('"over lazy"', True), + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Register a user and create a room, create some messages + self.register_user("alice", "password") + self.access_token = self.login("alice", "password") + self.room_id = self.helper.create_room_as("alice", tok=self.access_token) + + # Send the phrase as a message and check it was created + response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) + self.assertIn("event_id", response) + + def test_tokenize_query(self) -> None: + """Test the custom logic to tokenize a user's query.""" + cases = ( + ("brown", ["brown"]), + ("quick brown", ["quick", SearchToken.And, "brown"]), + ("quick \t brown", ["quick", SearchToken.And, "brown"]), + ('"brown quick"', [Phrase(["brown", "quick"])]), + ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]), + ("fox -brown", ["fox", SearchToken.Not, "brown"]), + ("- fox", [SearchToken.Not, "fox"]), + ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), + # No trailing double quoe. + ('"fox quick', ["fox", SearchToken.And, "quick"]), + ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + ('" quick "', [Phrase(["quick"])]), + ( + 'q"uick brow"n', + [ + "q", + SearchToken.And, + Phrase(["uick", "brow"]), + SearchToken.And, + "n", + ], + ), + ( + '-"quick brown"', + [SearchToken.Not, Phrase(["quick", "brown"])], + ), + ) + + for query, expected in cases: + tokenized = _tokenize_query(query) + self.assertEqual( + tokenized, expected, f"{tokenized} != {expected} for {query}" + ) + + def _check_test_cases( + self, + store: DataStore, + cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], + index=0, + ) -> None: + # Run all the test cases versus search_msgs + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_msgs([self.room_id], query, ["content.body"]) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + # Run them again versus search_rooms + for query, expect_to_contain in cases: + if isinstance(expect_to_contain, tuple): + expect_to_contain = expect_to_contain[index] + + result = self.get_success( + store.search_rooms([self.room_id], query, ["content.body"], 10) + ) + self.assertEquals( + result["count"], + 1 if expect_to_contain else 0, + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", + ) + self.assertEquals( + len(result["results"]), + 1 if expect_to_contain else 0, + "results array length should match count", + ) + + def test_postgres_web_search_for_phrase(self): + """ + Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. + This test is skipped unless the postgres instance supports websearch_to_tsquery. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + if store.database_engine.tsquery_func != "websearch_to_tsquery": + raise SkipTest( + "Test only applies when postgres supporting websearch_to_tsquery is used as the database" + ) + + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) + + def test_postgres_non_web_search_for_phrase(self): + """ + Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't + supported by the current postgres version. + """ + + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, PostgresEngine): + raise SkipTest("Test only applies when postgres is used as the database") + + # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. + with patch( + "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", + new_callable=PropertyMock, + ) as supports_websearch_to_tsquery: + supports_websearch_to_tsquery.return_value = "plainto_tsquery" + self._check_test_cases( + store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 + ) + + def test_sqlite_search(self): + """ + Test sqlite searching for phrases. + """ + store = self.hs.get_datastores().main + if not isinstance(store.database_engine, Sqlite3Engine): + raise SkipTest("Test only applies when sqlite is used as the database") + + self._check_test_cases(store, self.COMMON_CASES, index=0) -- cgit 1.5.1 From 67583281e3f8ea923eedbc56a4c85c7ba75d1582 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Oct 2022 09:58:12 -0400 Subject: Fix tests for change in PostgreSQL 14 behavior change. (#14310) PostgreSQL 14 changed the behavior of `websearch_to_tsquery` to improve some behaviour. The tests were hitting those edge-cases about handling of hanging double quotes. This fixes the tests to take into account the PostgreSQL version. --- changelog.d/14310.feature | 1 + synapse/storage/databases/main/search.py | 5 ++--- tests/storage/test_room_search.py | 16 ++++++++++++---- 3 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14310.feature (limited to 'tests/storage') diff --git a/changelog.d/14310.feature b/changelog.d/14310.feature new file mode 100644 index 0000000000..94c8a83212 --- /dev/null +++ b/changelog.d/14310.feature @@ -0,0 +1 @@ +Allow use of postgres and sqllite full-text search operators in search queries. \ No newline at end of file diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index a89fc54c2c..594b935614 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -824,9 +824,8 @@ def _tokenize_query(query: str) -> TokenList: in_phrase = False parts = deque(query.split('"')) for i, part in enumerate(parts): - # The contents inside double quotes is treated as a phrase, a trailing - # double quote is not implied. - in_phrase = bool(i % 2) and i != (len(parts) - 1) + # The contents inside double quotes is treated as a phrase. + in_phrase = bool(i % 2) # Pull out the individual words, discarding any non-word characters. words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE)) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 9ddc19900a..868b5bee84 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -239,7 +239,6 @@ class MessageSearchTest(HomeserverTestCase): ("fox -nope", (True, False)), ("fox -brown", (False, True)), ('"fox" quick', True), - ('"fox quick', True), ('"quick brown', True), ('" quick "', True), ('" nope"', False), @@ -269,6 +268,15 @@ class MessageSearchTest(HomeserverTestCase): response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token) self.assertIn("event_id", response) + # The behaviour of a missing trailing double quote changed in PostgreSQL 14 + # from ignoring the initial double quote to treating it as a phrase. + main_store = homeserver.get_datastores().main + found = False + if isinstance(main_store.database_engine, PostgresEngine): + assert main_store.database_engine._version is not None + found = main_store.database_engine._version < 140000 + self.COMMON_CASES.append(('"fox quick', (found, True))) + def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" cases = ( @@ -280,9 +288,9 @@ class MessageSearchTest(HomeserverTestCase): ("fox -brown", ["fox", SearchToken.Not, "brown"]), ("- fox", [SearchToken.Not, "fox"]), ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]), - # No trailing double quoe. - ('"fox quick', ["fox", SearchToken.And, "quick"]), - ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]), + # No trailing double quote. + ('"fox quick', [Phrase(["fox", "quick"])]), + ('"-fox quick', [Phrase(["-fox", "quick"])]), ('" quick "', [Phrase(["quick"])]), ( 'q"uick brow"n', -- cgit 1.5.1 From e9a4343cb2daa55503bb2a2d1431d83bf9773e68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Nov 2022 09:55:34 -0500 Subject: Drop support for Postgres 10 in full text search code. (#14397) --- changelog.d/14397.removal | 1 + synapse/storage/databases/main/search.py | 50 +++++++++++------------ synapse/storage/engines/postgres.py | 16 -------- tests/storage/test_room_search.py | 69 ++++++++------------------------ 4 files changed, 41 insertions(+), 95 deletions(-) create mode 100644 changelog.d/14397.removal (limited to 'tests/storage') diff --git a/changelog.d/14397.removal b/changelog.d/14397.removal new file mode 100644 index 0000000000..e96b3de2bd --- /dev/null +++ b/changelog.d/14397.removal @@ -0,0 +1 @@ +Remove support for PostgreSQL 10. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e9588d1755..3fe433f66c 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -463,18 +463,17 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term - tsquery_func = self.database_engine.tsquery_func - sql = f""" - SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank, + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank, room_id, event_id FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) + WHERE vector @@ websearch_to_tsquery('english', ?) """ args = [search_query, search_query] + args - count_sql = f""" + count_sql = """ SELECT room_id, count(*) as count FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) + WHERE vector @@ websearch_to_tsquery('english', ?) """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -523,9 +522,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -604,18 +601,17 @@ class SearchStore(SearchBackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): search_query = search_term - tsquery_func = self.database_engine.tsquery_func - sql = f""" - SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank, + sql = """ + SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, origin_server_ts, stream_ordering, room_id, event_id FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) AND + WHERE vector @@ websearch_to_tsquery('english', ?) AND """ args = [search_query, search_query] + args - count_sql = f""" + count_sql = """ SELECT room_id, count(*) as count FROM event_search - WHERE vector @@ {tsquery_func}('english', ?) AND + WHERE vector @@ websearch_to_tsquery('english', ?) AND """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): @@ -686,9 +682,7 @@ class SearchStore(SearchBackgroundUpdateStore): highlights = None if isinstance(self.database_engine, PostgresEngine): - highlights = await self._find_highlights_in_postgres( - search_query, events, tsquery_func - ) + highlights = await self._find_highlights_in_postgres(search_query, events) count_sql += " GROUP BY room_id" @@ -714,7 +708,7 @@ class SearchStore(SearchBackgroundUpdateStore): } async def _find_highlights_in_postgres( - self, search_query: str, events: List[EventBase], tsquery_func: str + self, search_query: str, events: List[EventBase] ) -> Set[str]: """Given a list of events and a search term, return a list of words that match from the content of the event. @@ -725,7 +719,6 @@ class SearchStore(SearchBackgroundUpdateStore): Args: search_query events: A list of events - tsquery_func: The tsquery_* function to use when making queries Returns: A set of strings. @@ -758,13 +751,16 @@ class SearchStore(SearchBackgroundUpdateStore): while stop_sel in value: stop_sel += ">" - query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % ( - _to_postgres_options( - { - "StartSel": start_sel, - "StopSel": stop_sel, - "MaxFragments": "50", - } + query = ( + "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)" + % ( + _to_postgres_options( + { + "StartSel": start_sel, + "StopSel": stop_sel, + "MaxFragments": "50", + } + ) ) ) txn.execute(query, (value, search_query)) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 0c4fd88914..719a517336 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -170,22 +170,6 @@ class PostgresEngine( """Do we support the `RETURNING` clause in insert/update/delete?""" return True - @property - def tsquery_func(self) -> str: - """ - Selects a tsquery_* func to use. - - Ref: https://www.postgresql.org/docs/current/textsearch-controls.html - - Returns: - The function name. - """ - # Postgres 11 added support for websearch_to_tsquery. - assert self._version is not None - if self._version >= 110000: - return "websearch_to_tsquery" - return "plainto_tsquery" - def is_deadlock(self, error: Exception) -> bool: if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 868b5bee84..ef850daa73 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import List, Tuple from unittest.case import SkipTest -from unittest.mock import PropertyMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -220,10 +219,8 @@ class MessageSearchTest(HomeserverTestCase): PHRASE = "the quick brown fox jumps over the lazy dog" - # Each entry is a search query, followed by either a boolean of whether it is - # in the phrase OR a tuple of booleans: whether it matches using websearch - # and using plain search. - COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + # Each entry is a search query, followed by a boolean of whether it is in the phrase. + COMMON_CASES = [ ("nope", False), ("brown", True), ("quick brown", True), @@ -231,13 +228,13 @@ class MessageSearchTest(HomeserverTestCase): ("quick \t brown", True), ("jump", True), ("brown nope", False), - ('"brown quick"', (False, True)), + ('"brown quick"', False), ('"jumps over"', True), - ('"quick fox"', (False, True)), + ('"quick fox"', False), ("nope OR doublenope", False), - ("furphy OR fox", (True, False)), - ("fox -nope", (True, False)), - ("fox -brown", (False, True)), + ("furphy OR fox", True), + ("fox -nope", True), + ("fox -brown", False), ('"fox" quick', True), ('"quick brown', True), ('" quick "', True), @@ -246,11 +243,11 @@ class MessageSearchTest(HomeserverTestCase): # TODO Test non-ASCII cases. # Case that fail on SQLite. - POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [ + POSTGRES_CASES = [ # SQLite treats NOT as a binary operator. - ("- fox", (False, True)), - ("- nope", (True, False)), - ('"-fox quick', (False, True)), + ("- fox", False), + ("- nope", True), + ('"-fox quick', False), # PostgreSQL skips stop words. ('"the quick brown"', True), ('"over lazy"', True), @@ -275,7 +272,7 @@ class MessageSearchTest(HomeserverTestCase): if isinstance(main_store.database_engine, PostgresEngine): assert main_store.database_engine._version is not None found = main_store.database_engine._version < 140000 - self.COMMON_CASES.append(('"fox quick', (found, True))) + self.COMMON_CASES.append(('"fox quick', found)) def test_tokenize_query(self) -> None: """Test the custom logic to tokenize a user's query.""" @@ -315,16 +312,10 @@ class MessageSearchTest(HomeserverTestCase): ) def _check_test_cases( - self, - store: DataStore, - cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]], - index=0, + self, store: DataStore, cases: List[Tuple[str, bool]] ) -> None: # Run all the test cases versus search_msgs for query, expect_to_contain in cases: - if isinstance(expect_to_contain, tuple): - expect_to_contain = expect_to_contain[index] - result = self.get_success( store.search_msgs([self.room_id], query, ["content.body"]) ) @@ -343,9 +334,6 @@ class MessageSearchTest(HomeserverTestCase): # Run them again versus search_rooms for query, expect_to_contain in cases: - if isinstance(expect_to_contain, tuple): - expect_to_contain = expect_to_contain[index] - result = self.get_success( store.search_rooms([self.room_id], query, ["content.body"], 10) ) @@ -366,38 +354,15 @@ class MessageSearchTest(HomeserverTestCase): """ Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery. This test is skipped unless the postgres instance supports websearch_to_tsquery. - """ - - store = self.hs.get_datastores().main - if not isinstance(store.database_engine, PostgresEngine): - raise SkipTest("Test only applies when postgres is used as the database") - - if store.database_engine.tsquery_func != "websearch_to_tsquery": - raise SkipTest( - "Test only applies when postgres supporting websearch_to_tsquery is used as the database" - ) - self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0) - - def test_postgres_non_web_search_for_phrase(self): - """ - Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't - supported by the current postgres version. + See https://www.postgresql.org/docs/current/textsearch-controls.html """ store = self.hs.get_datastores().main if not isinstance(store.database_engine, PostgresEngine): raise SkipTest("Test only applies when postgres is used as the database") - # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path. - with patch( - "synapse.storage.engines.postgres.PostgresEngine.tsquery_func", - new_callable=PropertyMock, - ) as supports_websearch_to_tsquery: - supports_websearch_to_tsquery.return_value = "plainto_tsquery" - self._check_test_cases( - store, self.COMMON_CASES + self.POSTGRES_CASES, index=1 - ) + self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES) def test_sqlite_search(self): """ @@ -407,4 +372,4 @@ class MessageSearchTest(HomeserverTestCase): if not isinstance(store.database_engine, Sqlite3Engine): raise SkipTest("Test only applies when sqlite is used as the database") - self._check_test_cases(store, self.COMMON_CASES, index=0) + self._check_test_cases(store, self.COMMON_CASES) -- cgit 1.5.1 From 882277008c7b43ab26e3445ab94a38aa25ad0965 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 16 Nov 2022 15:01:22 +0000 Subject: Fix background updates failing to add unique indexes on receipts (#14453) As part of the database migration to support threaded receipts, there is a possible window in between `73/08thread_receipts_non_null.sql.postgres` removing the original unique constraints on `receipts_linearized` and `receipts_graph` and the `reeipts_linearized_unique_index` and `receipts_graph_unique_index` background updates from `72/08thread_receipts.sql` completing where the unique constraints on `receipts_linearized` and `receipts_graph` are missing. Any emulated upserts on these tables must therefore be performed with a lock held, otherwise duplicate rows can end up in the tables when there are concurrent emulated upserts. Fix the missing lock. Note that emulated upserts no longer happen by default on sqlite, since the minimum supported version of sqlite supports native upserts by default now. Finally, clean up any duplicate receipts that may have crept in before trying to create the `receipts_graph_unique_index` and `receipts_linearized_unique_index` unique indexes. Signed-off-by: Sean Quah --- changelog.d/14453.bugfix | 1 + synapse/storage/databases/main/receipts.py | 171 ++++++++++++++++++--- tests/storage/databases/main/test_receipts.py | 209 ++++++++++++++++++++++++++ 3 files changed, 357 insertions(+), 24 deletions(-) create mode 100644 changelog.d/14453.bugfix create mode 100644 tests/storage/databases/main/test_receipts.py (limited to 'tests/storage') diff --git a/changelog.d/14453.bugfix b/changelog.d/14453.bugfix new file mode 100644 index 0000000000..4969e5450c --- /dev/null +++ b/changelog.d/14453.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where the background updates to add non-thread unique indexes on receipts could fail when upgrading from 1.67.0 or earlier. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dc6989527e..fbf27497ec 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -113,24 +113,6 @@ class ReceiptsWorkerStore(SQLBaseStore): prefilled_cache=receipts_stream_prefill, ) - self.db_pool.updates.register_background_index_update( - "receipts_linearized_unique_index", - index_name="receipts_linearized_unique_index", - table="receipts_linearized", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - - self.db_pool.updates.register_background_index_update( - "receipts_graph_unique_index", - index_name="receipts_graph_unique_index", - table="receipts_graph", - columns=["room_id", "receipt_type", "user_id"], - where_clause="thread_id IS NULL", - unique=True, - ) - def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @@ -702,9 +684,6 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_linearized has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) return rx_ts @@ -862,14 +841,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "data": json_encoder.encode(data), }, where_clause=where_clause, - # receipts_graph has a unique constraint on - # (user_id, room_id, receipt_type), so no need to lock - lock=False, ) class ReceiptsBackgroundUpdateStore(SQLBaseStore): POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering" + RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index" + RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index" def __init__( self, @@ -883,6 +861,14 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING, self._populate_receipt_event_stream_ordering, ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_linearized_unique_index, + ) + self.db_pool.updates.register_background_update_handler( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + self._background_receipts_graph_unique_index, + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int @@ -938,6 +924,143 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _create_receipts_index(self, index_name: str, table: str) -> None: + """Adds a unique index on `(room_id, receipt_type, user_id)` to the given + receipts table, for non-thread receipts.""" + + def _create_index(conn: LoggingDatabaseConnection) -> None: + conn.rollback() + + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=True) + + try: + c = conn.cursor() + + # Now that the duplicates are gone, we can create the index. + concurrently = ( + "CONCURRENTLY" + if isinstance(self.database_engine, PostgresEngine) + else "" + ) + sql = f""" + CREATE UNIQUE INDEX {concurrently} {index_name} + ON {table}(room_id, receipt_type, user_id) + WHERE thread_id IS NULL + """ + c.execute(sql) + finally: + if isinstance(self.database_engine, PostgresEngine): + conn.set_session(autocommit=False) + + await self.db_pool.runWithConnection(_create_index) + + async def _background_receipts_linearized_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT MAX(stream_id), room_id, receipt_type, user_id + FROM receipts_linearized + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn)) + + # Then remove duplicate receipts, keeping the one with the highest + # `stream_id`. There should only be a single receipt with any given + # `stream_id`. + for max_stream_id, room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_linearized + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL AND + stream_id < ? + """ + txn.execute(sql, (room_id, receipt_type, user_id, max_stream_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_linearized_unique_index", + "receipts_linearized", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + + async def _background_receipts_graph_unique_index( + self, progress: dict, batch_size: int + ) -> int: + """Removes duplicate receipts and adds a unique index on + `(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread + receipts.""" + + def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: + # Identify any duplicate receipts arising from + # https://github.com/matrix-org/synapse/issues/14406. + # We expect the following query to use the per-thread receipt index and take + # less than a minute. + sql = """ + SELECT room_id, receipt_type, user_id FROM receipts_graph + WHERE thread_id IS NULL + GROUP BY room_id, receipt_type, user_id + HAVING COUNT(*) > 1 + """ + txn.execute(sql) + duplicate_keys = cast(List[Tuple[str, str, str]], list(txn)) + + # Then remove all duplicate receipts. + # We could be clever and try to keep the latest receipt out of every set of + # duplicates, but it's far simpler to remove them all. + for room_id, receipt_type, user_id in duplicate_keys: + sql = """ + DELETE FROM receipts_graph + WHERE + room_id = ? AND + receipt_type = ? AND + user_id = ? AND + thread_id IS NULL + """ + txn.execute(sql, (room_id, receipt_type, user_id)) + + await self.db_pool.runInteraction( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, + _remote_duplicate_receipts_txn, + ) + + await self._create_receipts_index( + "receipts_graph_unique_index", + "receipts_graph", + ) + + await self.db_pool.updates._end_background_update( + self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME + ) + + return 1 + class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore): pass diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py new file mode 100644 index 0000000000..c4f12d81d7 --- /dev/null +++ b/tests/storage/databases/main/test_receipts.py @@ -0,0 +1,209 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Sequence, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store = hs.get_datastores().main + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + self.other_room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def _test_background_receipts_unique_index( + self, + update_name: str, + index_name: str, + table: str, + receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]], + expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]], + ): + """Test that the background update to uniqueify non-thread receipts in + the given receipts table works properly. + + Args: + update_name: The name of the background update to test. + index_name: The name of the index that the background update creates. + table: The table of receipts that the background update fixes. + receipts: The test data containing duplicate receipts. + A list of receipt rows to insert, grouped by + `(room_id, receipt_type, user_id)`. + expected_unique_receipts: A dictionary of `(room_id, receipt_type, user_id)` + keys and expected receipt key-values after duplicate receipts have been + removed. + """ + # First, undo the background update. + def drop_receipts_unique_index(txn: LoggingTransaction) -> None: + txn.execute(f"DROP INDEX IF EXISTS {index_name}") + + self.get_success( + self.store.db_pool.runInteraction( + "drop_receipts_unique_index", + drop_receipts_unique_index, + ) + ) + + # Populate the receipts table, including duplicates. + for (room_id, receipt_type, user_id), rows in receipts.items(): + for row in rows: + self.get_success( + self.store.db_pool.simple_insert( + table, + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "thread_id": None, + "data": "{}", + **row, + }, + ) + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": update_name, + "progress_json": "{}", + }, + ) + ) + + self.store.db_pool.updates._all_done = False + + self.wait_for_background_updates() + + # Check that the remaining receipts match expectations. + for ( + room_id, + receipt_type, + user_id, + ), expected_row in expected_unique_receipts.items(): + # Include the receipt key in the returned columns, for more informative + # assertion messages. + columns = ["room_id", "receipt_type", "user_id"] + if expected_row is not None: + columns += expected_row.keys() + + rows = self.get_success( + self.store.db_pool.simple_select_list( + table=table, + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + # `simple_select_onecol` does not support NULL filters, + # so skip the filter on `thread_id`. + }, + retcols=columns, + desc="get_receipt", + ) + ) + + if expected_row is not None: + self.assertEqual( + len(rows), + 1, + f"Background update did not leave behind latest receipt in {table}", + ) + self.assertEqual( + rows[0], + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + **expected_row, + }, + ) + else: + self.assertEqual( + len(rows), + 0, + f"Background update did not remove all duplicate receipts from {table}", + ) + + def test_background_receipts_linearized_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_linearized` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_linearized_unique_index", + "receipts_linearized_unique_index", + "receipts_linearized", + receipts={ + (self.room_id, "m.read", self.user_id): [ + {"stream_id": 5, "event_id": "$some_event"}, + {"stream_id": 6, "event_id": "$some_event"}, + ], + (self.other_room_id, "m.read", self.user_id): [ + {"stream_id": 7, "event_id": "$some_event"} + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): {"stream_id": 6}, + (self.other_room_id, "m.read", self.user_id): {"stream_id": 7}, + }, + ) + + def test_background_receipts_graph_unique_index(self): + """Test that the background update to uniqueify non-thread receipts in + `receipts_graph` works properly. + """ + self._test_background_receipts_unique_index( + "receipts_graph_unique_index", + "receipts_graph_unique_index", + "receipts_graph", + receipts={ + (self.room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + }, + { + "event_ids": '["$some_event"]', + }, + ], + (self.other_room_id, "m.read", self.user_id): [ + { + "event_ids": '["$some_event"]', + } + ], + }, + expected_unique_receipts={ + (self.room_id, "m.read", self.user_id): None, + (self.other_room_id, "m.read", self.user_id): { + "event_ids": '["$some_event"]' + }, + }, + ) -- cgit 1.5.1 From 115f0eb2334b13665e5c112bd87f95ea393c9047 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 16 Nov 2022 22:16:46 +0000 Subject: Reintroduce #14376, with bugfix for monoliths (#14468) * Add tests for StreamIdGenerator * Drive-by: annotate all defs * Revert "Revert "Remove slaved id tracker (#14376)" (#14463)" This reverts commit d63814fd736fed5d3d45ff3af5e6d3bfae50c439, which in turn reverted 36097e88c4da51fce6556a58c49bd675f4cf20ab. This restores the latter. * Fix StreamIdGenerator not handling unpersisted IDs Spotted by @erikjohnston. Closes #14456. * Changelog Co-authored-by: Nick Mills-Barrett Co-authored-by: Erik Johnston --- changelog.d/14376.misc | 1 + changelog.d/14468.misc | 1 + mypy.ini | 3 + synapse/replication/slave/__init__.py | 13 -- synapse/replication/slave/storage/__init__.py | 13 -- .../slave/storage/_slaved_id_tracker.py | 50 ------- synapse/storage/databases/main/account_data.py | 30 ++-- synapse/storage/databases/main/devices.py | 36 ++--- synapse/storage/databases/main/events_worker.py | 35 ++--- synapse/storage/databases/main/push_rule.py | 17 +-- synapse/storage/databases/main/pusher.py | 24 ++- synapse/storage/databases/main/receipts.py | 18 +-- synapse/storage/util/id_generators.py | 13 +- tests/storage/test_id_generators.py | 162 +++++++++++++++++++-- 14 files changed, 230 insertions(+), 186 deletions(-) create mode 100644 changelog.d/14376.misc create mode 100644 changelog.d/14468.misc delete mode 100644 synapse/replication/slave/__init__.py delete mode 100644 synapse/replication/slave/storage/__init__.py delete mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py (limited to 'tests/storage') diff --git a/changelog.d/14376.misc b/changelog.d/14376.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14376.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/changelog.d/14468.misc b/changelog.d/14468.misc new file mode 100644 index 0000000000..2ca326fea6 --- /dev/null +++ b/changelog.d/14468.misc @@ -0,0 +1 @@ +Remove old stream ID tracking code. Contributed by Nick @Beeper (@fizzadar). diff --git a/mypy.ini b/mypy.ini index 8f1141a239..53512b2584 100644 --- a/mypy.ini +++ b/mypy.ini @@ -117,6 +117,9 @@ disallow_untyped_defs = True [mypy-tests.state.test_profile] disallow_untyped_defs = True +[mypy-tests.storage.test_id_generators] +disallow_untyped_defs = True + [mypy-tests.storage.test_profile] disallow_untyped_defs = True diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py deleted file mode 100644 index f43a360a80..0000000000 --- a/synapse/replication/slave/storage/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py deleted file mode 100644 index 8f3f953ed4..0000000000 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Optional, Tuple - -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id - - -class SlavedIdTracker(AbstractStreamIdTracker): - """Tracks the "current" stream ID of a stream with a single writer. - - See `AbstractStreamIdTracker` for more details. - - Note that this class does not work correctly when there are multiple - writers. - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - table: str, - column: str, - extra_tables: Optional[List[Tuple[str, str]]] = None, - step: int = 1, - ): - self.step = step - self._current = _load_current_id(db_conn, table, column, step) - if extra_tables: - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) - - def advance(self, instance_name: Optional[str], new_id: int) -> None: - self._current = (max if self.step > 0 else min)(self._current, new_id) - - def get_current_token(self) -> int: - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c38b8a9e5a..282687ebce 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -68,12 +67,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. self._account_data_id_gen: AbstractStreamIdTracker + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) if isinstance(database.engine, PostgresEngine): - self._can_write_to_account_data = ( - self._instance_name in hs.config.worker.writers.account_data - ) - self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, @@ -95,21 +93,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if self._instance_name in hs.config.worker.writers.account_data: - self._can_write_to_account_data = True - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - else: - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + is_writer=self._instance_name in hs.config.worker.writers.account_data, + ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e114c733d1..57230df5ae 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -38,7 +38,6 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -86,28 +85,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) - else: - self._device_list_id_gen = SlavedIdTracker( - db_conn, - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ], + is_writer=hs.config.worker.worker_app is None, + ) # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8a104f7e93..01e935edef 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -213,26 +212,20 @@ class EventsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + is_writer=hs.get_instance_name() in hs.config.worker.writers.events, + ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8ae10f6127..12ad44dbb3 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -30,7 +30,6 @@ from typing import ( from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushRulesStream from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -111,14 +110,14 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, "push_rules_stream", "stream_id" - ) - else: - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "push_rules_stream", + "stream_id", + is_writer=hs.config.worker.worker_app is None, + ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 4a01562d45..fee37b9ce4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.push import PusherConfig, ThrottleParams -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import PushersStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -59,20 +58,15 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - if hs.config.worker.worker_app is None: - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) - else: - self._pushers_id_gen = SlavedIdTracker( - db_conn, - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - ) + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. + self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + db_conn, + "pushers", + "id", + extra_tables=[("deleted_pushers", "stream_id")], + is_writer=hs.config.worker.worker_app is None, + ) self.db_pool.updates.register_background_update_handler( "remove_deactivated_pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index fbf27497ec..a580e4bdda 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -27,7 +27,6 @@ from typing import ( ) from synapse.api.constants import EduTypes -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -61,6 +60,9 @@ class ReceiptsWorkerStore(SQLBaseStore): hs: "HomeServer", ): self._instance_name = hs.get_instance_name() + + # In the worker store this is an ID tracker which we overwrite in the non-worker + # class below that is used on the main process. self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): @@ -87,14 +89,12 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.receipts: - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - else: - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, + "receipts_linearized", + "stream_id", + is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts, + ) super().__init__(database, db_conn, hs) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2dfe4c0b66..0d7108f01b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator): column: str, extra_tables: Iterable[Tuple[str, str]] = (), step: int = 1, + is_writer: bool = True, ) -> None: assert step != 0 self._lock = threading.Lock() self._step: int = step self._current: int = _load_current_id(db_conn, table, column, step) + self._is_writer = is_writer for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator): self._unfinished_ids: OrderedDict[int, int] = OrderedDict() def advance(self, instance_name: str, new_id: int) -> None: - # `StreamIdGenerator` should only be used when there is a single writer, - # so replication should never happen. - raise Exception("Replication is not supported by StreamIdGenerator") + # Advance should never be called on a writer instance, only over replication + if self._is_writer: + raise Exception("Replication is not supported by writer StreamIdGenerator") + + self._current = (max if self._step > 0 else min)(self._current, new_id) def get_next(self) -> AsyncContextManager[int]: with self._lock: @@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) def get_current_token(self) -> int: + if not self._is_writer: + return self._current + with self._lock: if self._unfinished_ids: return next(iter(self._unfinished_ids)) - self._step diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 2d8d1f860f..d6a2b8d274 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -16,15 +16,157 @@ from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import IncorrectDatabaseSetup -from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS +class StreamIdGeneratorTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") + + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + table="foobar", + column="stream_id", + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. + await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" @@ -48,9 +190,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -446,7 +588,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("master", 3) # Now we add a row *without* updating the stream ID - def _insert(txn): + def _insert(txn: Cursor) -> None: txn.execute("INSERT INTO foobar VALUES (26, 'master')") self.get_success(self.db_pool.runInteraction("_insert", _insert)) @@ -481,9 +623,9 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -617,9 +759,9 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers: Optional[List[str]] = None + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: - def _create(conn): + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( conn, self.db_pool, @@ -641,7 +783,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name: str, number: int, update_stream_table: bool = True, - ): + ) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ -- cgit 1.5.1 From 9cae44f49e6bf4f6b8a20ab11a65da417bb1565f Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 22 Nov 2022 16:46:52 +0000 Subject: Track unconverted device list outbound pokes using a position instead (#14516) When a local device list change is added to `device_lists_changes_in_room`, the `converted_to_destinations` flag is set to `FALSE` and the `_handle_new_device_update_async` background process is started. This background process looks for unconverted rows in `device_lists_changes_in_room`, copies them to `device_lists_outbound_pokes` and updates the flag. To update the `converted_to_destinations` flag, the database performs a `DELETE` and `INSERT` internally, which fragments the table. To avoid this, track unconverted rows using a `(stream ID, room ID)` position instead of the flag. From now on, the `converted_to_destinations` column indicates rows that need converting to outbound pokes, but does not indicate whether the conversion has already taken place. Closes #14037. Signed-off-by: Sean Quah --- changelog.d/14516.misc | 1 + synapse/handlers/device.py | 30 +++++- synapse/storage/database.py | 13 +-- synapse/storage/databases/main/devices.py | 107 +++++++++++++-------- .../73/12refactor_device_list_outbound_pokes.sql | 53 ++++++++++ tests/storage/test_devices.py | 3 +- 6 files changed, 158 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14516.misc create mode 100644 synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql (limited to 'tests/storage') diff --git a/changelog.d/14516.misc b/changelog.d/14516.misc new file mode 100644 index 0000000000..51666c6ffc --- /dev/null +++ b/changelog.d/14516.misc @@ -0,0 +1 @@ +Refactor conversion of device list changes in room to outbound pokes to track unconverted rows using a `(stream ID, room ID)` position instead of updating the `converted_to_destinations` flag on every row. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c597639a7f..da3ddafeae 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -682,13 +682,33 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to: Set[str] = set() try: + stream_id, room_id = await self.store.get_device_change_last_converted_pos() + while True: self._handle_new_device_update_new_data = False - rows = await self.store.get_uncoverted_outbound_room_pokes() + max_stream_id = self.store.get_device_stream_token() + rows = await self.store.get_uncoverted_outbound_room_pokes( + stream_id, room_id + ) if not rows: # If the DB returned nothing then there is nothing left to # do, *unless* a new device list update happened during the # DB query. + + # Advance `(stream_id, room_id)`. + # `max_stream_id` comes from *before* the query for unconverted + # rows, which means that any unconverted rows must have a larger + # stream ID. + if max_stream_id > stream_id: + stream_id, room_id = max_stream_id, "" + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + else: + assert max_stream_id == stream_id + # Avoid moving `room_id` backwards. + pass + if self._handle_new_device_update_new_data: continue else: @@ -718,7 +738,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=stream_id, hosts=hosts, context=opentracing_context, ) @@ -752,6 +771,12 @@ class DeviceHandler(DeviceWorkerHandler): hosts_already_sent_to.update(hosts) current_stream_id = stream_id + # Advance `(stream_id, room_id)`. + _, _, room_id, stream_id, _ = rows[-1] + await self.store.set_device_change_last_converted_pos( + stream_id, room_id + ) + finally: self._handle_new_device_update_is_processing = False @@ -834,7 +859,6 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, room_id=room_id, - stream_id=None, hosts=potentially_changed_hosts, context=None, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0dc44b246c..a14b13aec8 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2075,13 +2075,14 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) + select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + + if keyvalues: + select_sql += " WHERE %s" % (" AND ".join("%s = ?" % k for k in keyvalues),) + txn.execute(select_sql, list(keyvalues.values())) + else: + txn.execute(select_sql) - txn.execute(select_sql, list(keyvalues.values())) row = txn.fetchone() if not row: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 57230df5ae..37629115ab 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -2008,27 +2008,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def get_uncoverted_outbound_room_pokes( - self, limit: int = 10 + self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: """Get device list changes by room that have not yet been handled and written to `device_lists_outbound_pokes`. + Args: + start_stream_id: Together with `start_room_id`, indicates the position after + which to return device list changes. + start_room_id: Together with `start_stream_id`, indicates the position after + which to return device list changes. + limit: The maximum number of device list changes to return. + Returns: - A list of user ID, device ID, room ID, stream ID and optional opentracing context. + A list of user ID, device ID, room ID, stream ID and optional opentracing + context, in order of ascending (stream ID, room ID). """ sql = """ SELECT user_id, device_id, room_id, stream_id, opentracing_context FROM device_lists_changes_in_room - WHERE NOT converted_to_destinations - ORDER BY stream_id + WHERE + (stream_id, room_id) > (?, ?) AND + stream_id <= ? AND + NOT converted_to_destinations + ORDER BY stream_id ASC, room_id ASC LIMIT ? """ def get_uncoverted_outbound_room_pokes_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: - txn.execute(sql, (limit,)) + txn.execute( + sql, + ( + start_stream_id, + start_room_id, + # Avoid returning rows if there may be uncommitted device list + # changes with smaller stream IDs. + self._device_list_id_gen.get_current_token(), + limit, + ), + ) return [ ( @@ -2050,49 +2071,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, room_id: str, - stream_id: Optional[int], hosts: Collection[str], context: Optional[Dict[str, str]], ) -> None: """Queue the device update to be sent to the given set of hosts, calculated from the room ID. - - Marks the associated row in `device_lists_changes_in_room` as handled, - if `stream_id` is provided. """ + if not hosts: + return def add_device_list_outbound_pokes_txn( txn: LoggingTransaction, stream_ids: List[int] ) -> None: - if hosts: - self._add_device_outbound_poke_to_stream_txn( - txn, - user_id=user_id, - device_id=device_id, - hosts=hosts, - stream_ids=stream_ids, - context=context, - ) - - if stream_id: - self.db_pool.simple_update_txn( - txn, - table="device_lists_changes_in_room", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "room_id": room_id, - }, - updatevalues={"converted_to_destinations": True}, - ) - - if not hosts: - # If there are no hosts then we don't try and generate stream IDs. - return await self.db_pool.runInteraction( - "add_device_list_outbound_pokes", - add_device_list_outbound_pokes_txn, - [], + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_id=device_id, + hosts=hosts, + stream_ids=stream_ids, + context=context, ) async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: @@ -2156,3 +2153,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "get_pending_remote_device_list_updates_for_room", get_pending_remote_device_list_updates_for_room_txn, ) + + async def get_device_change_last_converted_pos(self) -> Tuple[int, str]: + """ + Get the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + + Rows with a strictly greater position where `converted_to_destinations` is + `FALSE` have not been converted. + """ + + row = await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ) + return row["stream_id"], row["room_id"] + + async def set_device_change_last_converted_pos( + self, + stream_id: int, + room_id: str, + ) -> None: + """ + Set the position of the last row in `device_list_changes_in_room` that has been + converted to `device_lists_outbound_pokes`. + """ + + await self.db_pool.simple_update_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + updatevalues={"stream_id": stream_id, "room_id": room_id}, + desc="set_device_change_last_converted_pos", + ) diff --git a/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql new file mode 100644 index 0000000000..93d7fcb79b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/12refactor_device_list_outbound_pokes.sql @@ -0,0 +1,53 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Prior to this schema delta, we tracked the set of unconverted rows in +-- `device_lists_changes_in_room` using the `converted_to_destinations` flag. When rows +-- were converted to `device_lists_outbound_pokes`, the `converted_to_destinations` flag +-- would be set. +-- +-- After this schema delta, the `converted_to_destinations` is still populated like +-- before, but the set of unconverted rows is determined by the `stream_id` in the new +-- `device_lists_changes_converted_stream_position` table. +-- +-- If rolled back, Synapse will re-send all device list changes that happened since the +-- schema delta. + +CREATE TABLE IF NOT EXISTS device_lists_changes_converted_stream_position( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + -- The (stream id, room id) of the last row in `device_lists_changes_in_room` that + -- has been converted to `device_lists_outbound_pokes`. Rows with a strictly larger + -- (stream id, room id) where `converted_to_destinations` is `FALSE` have not been + -- converted. + stream_id BIGINT NOT NULL, + -- `room_id` may be an empty string, which compares less than all valid room IDs. + room_id TEXT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO device_lists_changes_converted_stream_position (stream_id, room_id) VALUES ( + ( + SELECT COALESCE( + -- The last converted stream id is the smallest unconverted stream id minus + -- one. + MIN(stream_id) - 1, + -- If there is no unconverted stream id, the last converted stream id is the + -- largest stream id. + -- Otherwise, pick 1, since stream ids start at 2. + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room) + ) FROM device_lists_changes_in_room WHERE NOT converted_to_destinations + ), + '' +); diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f37505b6cf..8e7db2c4ec 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -28,7 +28,7 @@ class DeviceStoreTestCase(HomeserverTestCase): """ for device_id in device_ids: - stream_id = self.get_success( + self.get_success( self.store.add_device_change_to_streams( user_id, [device_id], ["!some:room"] ) @@ -39,7 +39,6 @@ class DeviceStoreTestCase(HomeserverTestCase): user_id=user_id, device_id=device_id, room_id="!some:room", - stream_id=stream_id, hosts=[host], context={}, ) -- cgit 1.5.1 From 9af2be192a759c22d189b72cc0a7580cd9de8a37 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 24 Nov 2022 09:09:17 +0000 Subject: Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. (#14538) --- changelog.d/14538.removal | 1 + docs/upgrade.md | 22 ++ docs/usage/configuration/config_documentation.md | 25 -- synapse/app/_base.py | 16 +- synapse/app/generic_worker.py | 1 - synapse/app/homeserver.py | 1 - synapse/config/metrics.py | 2 - synapse/metrics/__init__.py | 7 +- synapse/metrics/_legacy_exposition.py | 288 ----------------------- synapse/metrics/_twisted_exposition.py | 38 +++ tests/storage/test_event_metrics.py | 7 +- 11 files changed, 70 insertions(+), 338 deletions(-) create mode 100644 changelog.d/14538.removal delete mode 100644 synapse/metrics/_legacy_exposition.py create mode 100644 synapse/metrics/_twisted_exposition.py (limited to 'tests/storage') diff --git a/changelog.d/14538.removal b/changelog.d/14538.removal new file mode 100644 index 0000000000..d2035ce82a --- /dev/null +++ b/changelog.d/14538.removal @@ -0,0 +1 @@ +Remove legacy Prometheus metrics names. They were deprecated in Synapse v1.69.0 and disabled by default in Synapse v1.71.0. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 2aa353e496..4fe9e4f02e 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,28 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.73.0 + +## Legacy Prometheus metric names have now been removed + +Synapse v1.69.0 included the deprecation of legacy Prometheus metric names +and offered an option to disable them. +Synapse v1.71.0 disabled legacy Prometheus metric names by default. + +This version, v1.73.0, removes those legacy Prometheus metric names entirely. +This also means that the `enable_legacy_metrics` configuration option has been +removed; it will no longer be possible to re-enable the legacy metric names. + +If you use metrics and have not yet updated your Grafana dashboard(s), +Prometheus console(s) or alerting rule(s), please consider doing so when upgrading +to this version. +Note that the included Grafana dashboard was updated in v1.72.0 to correct some +metric names which were missed when legacy metrics were disabled by default. + +See [v1.69.0: Deprecation of legacy Prometheus metric names](#deprecation-of-legacy-prometheus-metric-names) +for more context. + + # Upgrading to v1.72.0 ## Dropping support for PostgreSQL 10 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index f5937dd902..fae2771fad 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2437,31 +2437,6 @@ Example configuration: enable_metrics: true ``` --- -### `enable_legacy_metrics` - -Set to `true` to publish both legacy and non-legacy Prometheus metric names, -or to `false` to only publish non-legacy Prometheus metric names. -Defaults to `false`. Has no effect if `enable_metrics` is `false`. -**In Synapse v1.67.0 up to and including Synapse v1.70.1, this defaulted to `true`.** - -Legacy metric names include: -- metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules; -- counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard. - -These legacy metric names are unconventional and not compliant with OpenMetrics standards. -They are included for backwards compatibility. - -Example configuration: -```yaml -enable_legacy_metrics: false -``` - -See https://github.com/matrix-org/synapse/issues/11106 for context. - -*Since v1.67.0.* - -**Will be removed in v1.73.0.** ---- ### `sentry` Use this option to enable sentry integration. Provide the DSN assigned to you by sentry diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 41d2732ef9..a5aa2185a2 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -266,26 +266,18 @@ def register_start( reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics( - bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool -) -> None: +def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: """ Start Prometheus metrics server. """ from prometheus_client import start_http_server as start_http_server_prometheus - from synapse.metrics import ( - RegistryProxy, - start_http_server as start_http_server_legacy, - ) + from synapse.metrics import RegistryProxy for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - if enable_legacy_metric_names: - start_http_server_legacy(port, addr=host, registry=RegistryProxy) - else: - _set_prometheus_client_use_created_metrics(False) - start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + _set_prometheus_client_use_created_metrics(False) + start_http_server_prometheus(port, addr=host, registry=RegistryProxy) def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 74909b7d4a..46dc731696 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -320,7 +320,6 @@ class GenericWorkerServer(HomeServer): _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: logger.warning("Unsupported listener type: %s", listener.type) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 4f4fee4782..b9be558c7e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -265,7 +265,6 @@ class SynapseHomeServer(HomeServer): _base.listen_metrics( listener.bind_addresses, listener.port, - enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics, ) else: # this shouldn't happen, as the listener type should have been checked diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 6034a0346e..8c1c9bd12d 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -43,8 +43,6 @@ class MetricsConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.enable_metrics = config.get("enable_metrics", False) - self.enable_legacy_metrics = config.get("enable_legacy_metrics", False) - self.report_stats = config.get("report_stats", None) self.report_stats_endpoint = config.get( "report_stats_endpoint", "https://matrix.org/report-usage-stats/push" diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index c3d3daf877..b01372565d 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -47,11 +47,7 @@ from twisted.python.threadpool import ThreadPool # This module is imported for its side effects; flake8 needn't warn that it's unused. import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager -from synapse.metrics._legacy_exposition import ( - MetricsResource, - generate_latest, - start_http_server, -) +from synapse.metrics._twisted_exposition import MetricsResource, generate_latest from synapse.metrics._types import Collector from synapse.util import SYNAPSE_VERSION @@ -474,7 +470,6 @@ __all__ = [ "Collector", "MetricsResource", "generate_latest", - "start_http_server", "LaterGauge", "InFlightGauge", "GaugeBucketCollector", diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py deleted file mode 100644 index 1459f9d224..0000000000 --- a/synapse/metrics/_legacy_exposition.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2015-2019 Prometheus Python Client Developers -# Copyright 2019 Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This code is based off `prometheus_client/exposition.py` from version 0.7.1. - -Due to the renaming of metrics in prometheus_client 0.4.0, this customised -vendoring of the code will emit both the old versions that Synapse dashboards -expect, and the newer "best practice" version of the up-to-date official client. -""" -import logging -import math -import threading -from http.server import BaseHTTPRequestHandler, HTTPServer -from socketserver import ThreadingMixIn -from typing import Any, Dict, List, Type, Union -from urllib.parse import parse_qs, urlparse - -from prometheus_client import REGISTRY, CollectorRegistry -from prometheus_client.core import Sample - -from twisted.web.resource import Resource -from twisted.web.server import Request - -logger = logging.getLogger(__name__) -CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" - - -def floatToGoString(d: Union[int, float]) -> str: - d = float(d) - if d == math.inf: - return "+Inf" - elif d == -math.inf: - return "-Inf" - elif math.isnan(d): - return "NaN" - else: - s = repr(d) - dot = s.find(".") - # Go switches to exponents sooner than Python. - # We only need to care about positive values for le/quantile. - if d > 0 and dot > 6: - mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.") - return f"{mantissa}e+0{dot - 1}" - return s - - -def sample_line(line: Sample, name: str) -> str: - if line.labels: - labelstr = "{{{0}}}".format( - ",".join( - [ - '{}="{}"'.format( - k, - v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""), - ) - for k, v in sorted(line.labels.items()) - ] - ) - ) - else: - labelstr = "" - timestamp = "" - if line.timestamp is not None: - # Convert to milliseconds. - timestamp = f" {int(float(line.timestamp) * 1000):d}" - return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) - - -# Mapping from new metric names to legacy metric names. -# We translate these back to their old names when exposing them through our -# legacy vendored exporter. -# Only this legacy exposition module applies these name changes. -LEGACY_METRIC_NAMES = { - "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits", - "synapse_util_caches_cache_size": "synapse_util_caches_cache:size", - "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size", - "synapse_util_caches_cache": "synapse_util_caches_cache:total", - "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size", - "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits", - "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size", - "synapse_util_caches_response_cache": "synapse_util_caches_response_cache:total", - "synapse_federation_client_sent_pdu_destinations": "synapse_federation_client_sent_pdu_destinations:total", - "synapse_federation_client_sent_pdu_destinations_count": "synapse_federation_client_sent_pdu_destinations:count", - "synapse_admin_mau_current": "synapse_admin_mau:current", - "synapse_admin_mau_max": "synapse_admin_mau:max", - "synapse_admin_mau_registered_reserved_users": "synapse_admin_mau:registered_reserved_users", -} - - -def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: - """ - Generate metrics in legacy format. Modern metrics are generated directly - by prometheus-client. - """ - - output = [] - - for metric in registry.collect(): - if not metric.samples: - # No samples, don't bother. - continue - - # Translate to legacy metric name if it has one. - mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name) - mnewname = metric.name - mtype = metric.type - - # OpenMetrics -> Prometheus - if mtype == "counter": - mnewname = mnewname + "_total" - elif mtype == "info": - mtype = "gauge" - mnewname = mnewname + "_info" - elif mtype == "stateset": - mtype = "gauge" - elif mtype == "gaugehistogram": - mtype = "histogram" - elif mtype == "unknown": - mtype = "untyped" - - # Output in the old format for compatibility. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname} {mtype}\n") - - om_samples: Dict[str, List[str]] = {} - for s in metric.samples: - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - # OpenMetrics specific sample, put in a gauge at the end. - # (these come from gaugehistograms which don't get renamed, - # so no need to faff with mnewname) - om_samples.setdefault(suffix, []).append(sample_line(s, s.name)) - break - else: - newname = s.name.replace(mnewname, mname) - if ":" in newname and newname.endswith("_total"): - newname = newname[: -len("_total")] - output.append(sample_line(s, newname)) - - for suffix, lines in sorted(om_samples.items()): - if emit_help: - output.append( - "# HELP {}{} {}\n".format( - mname, - suffix, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mname}{suffix} gauge\n") - output.extend(lines) - - # Get rid of the weird colon things while we're at it - if mtype == "counter": - mnewname = mnewname.replace(":total", "") - mnewname = mnewname.replace(":", "_") - - if mname == mnewname: - continue - - # Also output in the new format, if it's different. - if emit_help: - output.append( - "# HELP {} {}\n".format( - mnewname, - metric.documentation.replace("\\", r"\\").replace("\n", r"\n"), - ) - ) - output.append(f"# TYPE {mnewname} {mtype}\n") - - for s in metric.samples: - # Get rid of the OpenMetrics specific samples (we should already have - # dealt with them above anyway.) - for suffix in ["_created", "_gsum", "_gcount"]: - if s.name == mname + suffix: - break - else: - sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name) - output.append( - sample_line(s, sample_name.replace(":total", "").replace(":", "_")) - ) - - return "".join(output).encode("utf-8") - - -class MetricsHandler(BaseHTTPRequestHandler): - """HTTP handler that gives metrics from ``REGISTRY``.""" - - registry = REGISTRY - - def do_GET(self) -> None: - registry = self.registry - params = parse_qs(urlparse(self.path).query) - - if "help" in params: - emit_help = True - else: - emit_help = False - - try: - output = generate_latest(registry, emit_help=emit_help) - except Exception: - self.send_error(500, "error generating metric output") - raise - try: - self.send_response(200) - self.send_header("Content-Type", CONTENT_TYPE_LATEST) - self.send_header("Content-Length", str(len(output))) - self.end_headers() - self.wfile.write(output) - except BrokenPipeError as e: - logger.warning( - "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e - ) - - def log_message(self, format: str, *args: Any) -> None: - """Log nothing.""" - - @classmethod - def factory(cls, registry: CollectorRegistry) -> Type: - """Returns a dynamic MetricsHandler class tied - to the passed registry. - """ - # This implementation relies on MetricsHandler.registry - # (defined above and defaulted to REGISTRY). - - # As we have unicode_literals, we need to create a str() - # object for type(). - cls_name = str(cls.__name__) - MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry}) - return MyMetricsHandler - - -class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): - """Thread per request HTTP server.""" - - # Make worker threads "fire and forget". Beginning with Python 3.7 this - # prevents a memory leak because ``ThreadingMixIn`` starts to gather all - # non-daemon threads in a list in order to join on them at server close. - # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the - # same as Python 3.7's ``ThreadingHTTPServer``. - daemon_threads = True - - -def start_http_server( - port: int, addr: str = "", registry: CollectorRegistry = REGISTRY -) -> None: - """Starts an HTTP server for prometheus metrics as a daemon thread""" - CustomMetricsHandler = MetricsHandler.factory(registry) - httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) - t = threading.Thread(target=httpd.serve_forever) - t.daemon = True - t.start() - - -class MetricsResource(Resource): - """ - Twisted ``Resource`` that serves prometheus metrics. - """ - - isLeaf = True - - def __init__(self, registry: CollectorRegistry = REGISTRY): - self.registry = registry - - def render_GET(self, request: Request) -> bytes: - request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) - response = generate_latest(self.registry) - request.setHeader(b"Content-Length", str(len(response))) - return response diff --git a/synapse/metrics/_twisted_exposition.py b/synapse/metrics/_twisted_exposition.py new file mode 100644 index 0000000000..0abcd14953 --- /dev/null +++ b/synapse/metrics/_twisted_exposition.py @@ -0,0 +1,38 @@ +# Copyright 2015-2019 Prometheus Python Client Developers +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from prometheus_client import REGISTRY, CollectorRegistry, generate_latest + +from twisted.web.resource import Resource +from twisted.web.server import Request + +CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" + + +class MetricsResource(Resource): + """ + Twisted ``Resource`` that serves prometheus metrics. + """ + + isLeaf = True + + def __init__(self, registry: CollectorRegistry = REGISTRY): + self.registry = registry + + def render_GET(self, request: Request) -> bytes: + request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) + response = generate_latest(self.registry) + request.setHeader(b"Content-Length", str(len(response))) + return response diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 088fbb247b..6f1135eef4 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from prometheus_client import generate_latest -from synapse.metrics import REGISTRY, generate_latest +from synapse.metrics import REGISTRY from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -53,8 +54,8 @@ class ExtremStatisticsTestCase(HomeserverTestCase): items = list( filter( - lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY, emit_help=False).split(b"\n"), + lambda x: b"synapse_forward_extremities_" in x and b"# HELP" not in x, + generate_latest(REGISTRY).split(b"\n"), ) ) -- cgit 1.5.1