From ab18441573dc14cea1fe4082b2a89b9d392a4b9f Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Fri, 5 Aug 2022 17:09:33 +0200 Subject: Support stable identifiers for MSC2285: private read receipts. (#13273) This adds support for the stable identifiers of MSC2285 while continuing to support the unstable identifiers behind the configuration flag. These will be removed in a future version. --- synapse/replication/tcp/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse/replication/tcp/client.py') diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..1ed7230e32 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,7 +416,10 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: + if receipt.receipt_type in ( + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ): continue receipt_info = ReadReceipt( receipt.room_id, -- cgit 1.5.1 From 0e99f07952edcb6396654e34da50ddeb0a211067 Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Thu, 1 Sep 2022 14:31:54 +0200 Subject: Remove support for unstable private read receipts (#13653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- changelog.d/13653.removal | 1 + synapse/api/constants.py | 1 - synapse/config/experimental.py | 3 -- synapse/handlers/receipts.py | 29 +++---------- synapse/replication/tcp/client.py | 5 +-- synapse/rest/client/notifications.py | 1 - synapse/rest/client/read_marker.py | 2 - synapse/rest/client/receipts.py | 2 - synapse/rest/client/versions.py | 1 - .../storage/databases/main/event_push_actions.py | 2 - tests/handlers/test_receipts.py | 48 ++++++---------------- tests/rest/client/test_sync.py | 37 +++++------------ tests/storage/test_receipts.py | 34 ++++++--------- 13 files changed, 44 insertions(+), 122 deletions(-) create mode 100644 changelog.d/13653.removal (limited to 'synapse/replication/tcp/client.py') diff --git a/changelog.d/13653.removal b/changelog.d/13653.removal new file mode 100644 index 0000000000..eb075d4517 --- /dev/null +++ b/changelog.d/13653.removal @@ -0,0 +1 @@ +Remove support for unstable [private read receipts](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c73aea622a..c178ddf070 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -258,7 +258,6 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" READ_PRIVATE: Final = "m.read.private" - UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1ff417539..260db49cad 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,9 +32,6 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (unstable private read receipts) - self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) - # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index d4a866b346..d2bdb9c8be 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -163,10 +163,7 @@ class ReceiptsHandler: if not is_new: return - if self.federation_sender and receipt_type not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: await self.federation_sender.send_read_receipt(receipt) @@ -206,38 +203,24 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id, orig_event_content in room.get("content", {}).items(): event_content = orig_event_content # If there are private read receipts, additional logic is necessary. - if ( - ReceiptTypes.READ_PRIVATE in event_content - or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content - ): + if ReceiptTypes.READ_PRIVATE in event_content: # Make a copy without private read receipts to avoid leaking # other user's private read receipts.. event_content = { receipt_type: receipt_value for receipt_type, receipt_value in event_content.items() - if receipt_type - not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ) + if receipt_type != ReceiptTypes.READ_PRIVATE } # Copy the current user's private read receipt from the # original content, if it exists. - user_private_read_receipt = orig_event_content.get( - ReceiptTypes.READ_PRIVATE, {} - ).get(user_id, None) + user_private_read_receipt = orig_event_content[ + ReceiptTypes.READ_PRIVATE + ].get(user_id, None) if user_private_read_receipt: event_content[ReceiptTypes.READ_PRIVATE] = { user_id: user_private_read_receipt } - user_unstable_private_read_receipt = orig_event_content.get( - ReceiptTypes.UNSTABLE_READ_PRIVATE, {} - ).get(user_id, None) - if user_unstable_private_read_receipt: - event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = { - user_id: user_unstable_private_read_receipt - } # Include the event if there is at least one non-private read # receipt or the current user has a private read receipt. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 1ed7230e32..e4f2201c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,10 +416,7 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index a73322a6a4..61268e3af1 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -62,7 +62,6 @@ class NotificationsServlet(RestServlet): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index aaad8b233f..5e53096539 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -45,8 +45,6 @@ class ReadMarkerRestServlet(RestServlet): ReceiptTypes.FULLY_READ, ReceiptTypes.READ_PRIVATE, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index c6108fc5eb..5b7fad7402 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,8 +49,6 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c9a830cbac..c516cda95d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -95,7 +95,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec - "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 9f410d69de..f4a07de2a3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -274,7 +274,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types=( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) @@ -468,7 +467,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5f70a2db79..b55238650c 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,8 +15,6 @@ from copy import deepcopy from typing import List -from parameterized import parameterized - from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict @@ -27,16 +25,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt(self, receipt_type: str) -> None: + def test_filters_out_private_receipt(self) -> None: self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -50,18 +45,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt_and_ignores_rest( - self, receipt_type: str - ) -> None: + def test_filters_out_private_receipt_and_ignores_rest(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -94,18 +84,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -175,18 +162,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -262,16 +246,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None: + def test_leaves_our_private_and_their_public(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -296,7 +277,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -319,16 +300,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_we_do_not_mutate(self, receipt_type: str) -> None: + def test_we_do_not_mutate(self) -> None: """Ensure the input values are not modified.""" events = [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index de0dec8539..0af643ecd9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -391,7 +391,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["experimental_features"] = {"msc2285_enabled": True} return self.setup_test_homeserver(config=config) @@ -413,17 +412,14 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_read_receipts(self, receipt_type: str) -> None: + def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -432,10 +428,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_public_receipt_can_override_private(self, receipt_type: str) -> None: + def test_public_receipt_can_override_private(self) -> None: """ Sending a public read receipt to the same event which has a private read receipt should cause that receipt to become public. @@ -446,7 +439,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -465,10 +458,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None: + def test_private_receipt_cannot_override_public(self) -> None: """ Sending a private read receipt to the same event which has a public read receipt should cause no change. @@ -489,7 +479,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -554,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): config = super().default_config() config["experimental_features"] = { "msc2654_enabled": True, - "msc2285_enabled": True, } return config @@ -601,10 +590,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_unread_counts(self, receipt_type: str) -> None: + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -638,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok, ) @@ -726,7 +712,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -738,7 +724,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ] ) def test_read_receipts_only_go_down(self, receipt_type: str) -> None: @@ -752,7 +737,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Read last event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -763,7 +748,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}", {}, access_token=self.tok, ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 191c957fb5..c89bfff241 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from parameterized import parameterized from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -92,7 +91,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -104,7 +102,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -117,16 +114,12 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) self.assertEqual(res, None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_receipts_for_user(self, receipt_type: str) -> None: + def test_get_receipts_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -144,14 +137,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -164,7 +157,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -187,20 +180,17 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: + def test_get_last_receipt_event_id_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -218,7 +208,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) @@ -227,7 +217,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event1_2_id) @@ -243,7 +233,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [receipt_type] + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, event1_2_id) @@ -269,14 +259,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From 8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 15:39:01 +0100 Subject: Support enabling/disabling pushers (from MSC3881) (#13799) Partial implementation of MSC3881 --- changelog.d/13799.feature | 1 + synapse/_scripts/synapse_port_db.py | 1 + synapse/config/experimental.py | 3 + synapse/handlers/register.py | 4 +- synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 81 ++++++++--- synapse/replication/tcp/client.py | 10 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/pusher.py | 18 ++- synapse/storage/databases/main/pusher.py | 69 ++++++---- .../schema/main/delta/73/02add_pusher_enabled.sql | 16 +++ tests/push/test_email.py | 4 +- tests/push/test_http.py | 148 +++++++++++++++++++-- tests/replication/test_pusher_shard.py | 2 +- tests/rest/admin/test_user.py | 2 +- 15 files changed, 294 insertions(+), 71 deletions(-) create mode 100644 changelog.d/13799.feature create mode 100644 synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql (limited to 'synapse/replication/tcp/client.py') diff --git a/changelog.d/13799.feature b/changelog.d/13799.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13799.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 30983c47fb..450ba462ba 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = { "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], + "pushers": ["enabled"], } diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702b81e636..f4541a8db0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -93,3 +93,6 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + + # MSC3881: Remotely toggle push notifications for another client + self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 20ec22105a..cfcadb34db 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -997,7 +997,7 @@ class RegistrationHandler: assert user_tuple token_id = user_tuple.token_id - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -1005,7 +1005,7 @@ class RegistrationHandler: app_display_name="Email Notifications", device_display_name=threepid["address"], pushkey=threepid["address"], - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 57c4d70466..ac99d35a7e 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -116,6 +116,7 @@ class PusherConfig: last_stream_ordering: int last_success: Optional[int] failing_since: Optional[int] + enabled: bool def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -128,6 +129,7 @@ class PusherConfig: "lang": self.lang, "profile_tag": self.profile_tag, "pushkey": self.pushkey, + "enabled": self.enabled, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 1e0ef44fc7..2597898cf4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -94,7 +94,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - async def add_pusher( + async def add_or_update_pusher( self, user_id: str, access_token: Optional[int], @@ -106,6 +106,7 @@ class PusherPool: lang: Optional[str], data: JsonDict, profile_tag: str = "", + enabled: bool = True, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -147,9 +148,20 @@ class PusherPool: last_stream_ordering=last_stream_ordering, last_success=None, failing_since=None, + enabled=enabled, ) ) + # Before we actually persist the pusher, we check if the user already has one + # for this app ID and pushkey. If so, we want to keep the access token in place, + # since this could be one device modifying (e.g. enabling/disabling) another + # device's pusher. + existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) + if existing_config: + access_token = existing_config.access_token + await self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -163,8 +175,9 @@ class PusherPool: data=data, last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, + enabled=enabled, ) - pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) return pusher @@ -276,10 +289,25 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - async def start_pusher_by_id( + async def _get_pusher_config_for_user_by_app_id_and_pushkey( + self, user_id: str, app_id: str, pushkey: str + ) -> Optional[PusherConfig]: + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + + pusher_config = None + for r in resultlist: + if r.user_name == user_id: + pusher_config = r + + return pusher_config + + async def process_pusher_change_by_id( self, app_id: str, pushkey: str, user_id: str ) -> Optional[Pusher]: - """Look up the details for the given pusher, and start it + """Look up the details for the given pusher, and either start it if its + "enabled" flag is True, or try to stop it otherwise. + + If the pusher is new and its "enabled" flag is False, the stop is a noop. Returns: The pusher started, if any @@ -290,12 +318,13 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return None - resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) - pusher_config = None - for r in resultlist: - if r.user_name == user_id: - pusher_config = r + if pusher_config and not pusher_config.enabled: + self.maybe_stop_pusher(app_id, pushkey, user_id) + return None pusher = None if pusher_config: @@ -305,7 +334,7 @@ class PusherPool: async def _start_pushers(self) -> None: """Start all the pushers""" - pushers = await self.store.get_all_pushers() + pushers = await self.store.get_enabled_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -363,6 +392,8 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() + logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey) + # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. @@ -382,16 +413,7 @@ class PusherPool: return pusher async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: - appid_pushkey = "%s:%s" % (app_id, pushkey) - - byuser = self.pushers.get(user_id, {}) - - if appid_pushkey in byuser: - logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - pusher = byuser.pop(appid_pushkey) - pusher.on_stop() - - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() + self.maybe_stop_pusher(app_id, pushkey, user_id) # We can only delete pushers on master. if self._remove_pusher_client: @@ -402,3 +424,22 @@ class PusherPool: await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) + + def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: + """Stops a pusher with the given app ID and push key if one is running. + + Args: + app_id: the pusher's app ID. + pushkey: the pusher's push key. + user_id: the user the pusher belongs to. Only used for logging. + """ + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..cf9cd6833b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,7 +189,9 @@ class ReplicationDataHandler: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: - await self.start_pusher(row.user_id, row.app_id, row.pushkey) + await self.process_pusher_change( + row.user_id, row.app_id, row.pushkey + ) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -334,13 +336,15 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: + async def process_pusher_change( + self, user_id: str, app_id: str, pushkey: str + ) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id) class FederationSenderHandler: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2ca6b2d08a..1274773d7e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet): and self.hs.config.email.email_notif_for_new_users and medium == "email" ): - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=None, kind="email", @@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet): app_display_name="Email Notifications", device_display_name=address, pushkey=address, - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 9a1f10f4be..c9f76125dc 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet): user.to_string() ) - filtered_pushers = [p.as_dict() for p in pushers] + pusher_dicts = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers} + for pusher in pusher_dicts: + if self._msc3881_enabled: + pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + del pusher["enabled"] + + return 200, {"pushers": pusher_dicts} class PushersSetRestServlet(RestServlet): @@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet): self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet): if "append" in content: append = content["append"] + enabled = True + if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content: + enabled = content["org.matrix.msc3881.enabled"] + if not append: await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], @@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet): ) try: - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet): lang=content["lang"], data=content["data"], profile_tag=content.get("profile_tag", ""), + enabled=enabled, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bd0cfa7f32..ee55b8c4a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore): ) continue + # If we're using SQLite, then boolean values are integers. This is + # troublesome since some code using the return value of this method might + # expect it to be a boolean, or will expose it to clients (in responses). + r["enabled"] = bool(r["enabled"]) + yield PusherConfig(**r) async def get_pushers_by_app_id_and_pushkey( @@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore): return await self.get_pushers_by({"user_name": user_id}) async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: - ret = await self.db_pool.simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], + """Retrieve pushers that match the given criteria. + + Args: + keyvalues: A {column: value} dictionary. + + Returns: + The pushers for which the given columns have the given values. + """ + + def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: + # We could technically use simple_select_list here, but we need to call + # COALESCE on the 'enabled' column. While it is technically possible to give + # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it + # feels a bit hacky, so it's probably better to just inline the query. + sql = """ + SELECT + id, user_name, access_token, profile_tag, kind, app_id, + app_display_name, device_display_name, pushkey, ts, lang, data, + last_stream_ordering, last_success, failing_since, + COALESCE(enabled, TRUE) AS enabled + FROM pushers + """ + + sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),) + + txn.execute(sql, list(keyvalues.values())) + + return self.db_pool.cursor_to_dict(txn) + + ret = await self.db_pool.runInteraction( desc="get_pushers_by", + func=get_pushers_by_txn, ) + return self._decode_pushers_rows(ret) - async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: - txn.execute("SELECT * FROM pushers") + async def get_enabled_pushers(self) -> Iterator[PusherConfig]: + def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: + txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - return await self.db_pool.runInteraction("get_all_pushers", get_pushers) + return await self.db_pool.runInteraction( + "get_enabled_pushers", get_enabled_pushers_txn + ) async def get_all_updated_pushers_rows( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore): data: Optional[JsonDict], last_stream_ordering: int, profile_tag: str = "", + enabled: bool = True, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore): "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, + "enabled": enabled, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql new file mode 100644 index 0000000000..dba3b4900b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE pushers ADD COLUMN enabled BOOLEAN; \ No newline at end of file diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7a3b0d6755..fd14568f55 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.pusher = self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", @@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase): """ with self.assertRaises(SynapseError) as cm: self.get_success_or_raise( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d9c68cdd2d..af67d84463 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable -from synapse.push import PusherConfigException -from synapse.rest.client import login, push_rule, receipts, room +from synapse.push import PusherConfig, PusherConfigException +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase): login.register_servlets, receipts.register_servlets, push_rule.register_servlets, + pusher.register_servlets, ] user_id = True hijack_auth = False @@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase): def test_data(data: Optional[JsonDict]) -> None: self.get_failure( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + def _make_user_with_pusher( + self, username: str, enabled: bool = True + ) -> Tuple[str, str]: + """Registers a user and creates a pusher for them. + + Args: + username: the localpart of the new user's Matrix ID. + enabled: whether to create the pusher in an enabled or disabled state. + """ user_id = self.register_user(username, "pass") access_token = self.login(username, "pass") # Register the pusher + self._set_pusher(user_id, access_token, enabled) + + return user_id, access_token + + def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None: + """Creates or updates the pusher for the given user. + + Args: + user_id: the user's Matrix ID. + access_token: the access token associated with the pusher. + enabled: whether to enable or disable the pusher. + """ user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase): pushkey="a@example.com", lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, + enabled=enabled, ) ) - return user_id, access_token - def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification @@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase): # The user sends a message back (sends a notification) self.helper.send(room, body="Hello", tok=access_token) self.assertEqual(len(self.push_attempts), 1) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_disable(self) -> None: + """Tests that disabling a pusher means it's not pushed to anymore.""" + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it generated a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Disable the pusher. + self._set_pusher(user_id, access_token, enabled=False) + + # Send another message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as disabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertFalse(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_enable(self) -> None: + """Tests that enabling a disabled pusher means it gets pushed to.""" + # Create the user with the pusher already disabled. + user_id, access_token = self._make_user_with_pusher("user", enabled=False) + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # Enable the pusher. + self._set_pusher(user_id, access_token, enabled=True) + + # Send another message and check that it did generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as enabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertTrue(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_null_enabled(self) -> None: + """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers + created before the column was introduced) is considered enabled. + """ + # We intentionally set 'enabled' to None so that it's stored as NULL in the + # database. + user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type] + + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) + + def test_update_different_device_access_token(self) -> None: + """Tests that if we create a pusher from one device, the update it from another + device, the access token associated with the pusher stays the same. + """ + # Create a user with a pusher. + user_id, access_token = self._make_user_with_pusher("user") + + # Get the token ID for the current access token, since that's what we store in + # the pushers table. + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + # Generate a new access token, and update the pusher with it. + new_token = self.login("user", "pass") + self._set_pusher(user_id, new_token, enabled=False) + + # Get the current list of pushers for the user. + ret = self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) + pushers: List[PusherConfig] = list(ret) + + # Check that we still have one pusher, and that the access token associated with + # it didn't change. + self.assertEqual(len(pushers), 1) + self.assertEqual(pushers[0].access_token, token_id) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 8f4f6688ce..59fea93e49 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): token_id = user_dict.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9f536ceeb3..1847e6ad6b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.other_user, access_token=token_id, kind="http", -- cgit 1.5.1 From efd108b45d1706526416bc9a6f89463b5ff4506a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 23 Sep 2022 10:33:28 -0400 Subject: Accept & store thread IDs for receipts (implement MSC3771). (#13782) Updates the `/receipts` endpoint and receipt EDU handler to parse a `thread_id` from the body and insert it in the database. --- changelog.d/13782.feature | 1 + synapse/config/experimental.py | 2 + synapse/handlers/receipts.py | 23 ++++++- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/streams/_base.py | 1 + synapse/rest/client/read_marker.py | 2 + synapse/rest/client/receipts.py | 14 ++++- synapse/rest/client/versions.py | 2 + synapse/storage/database.py | 2 + synapse/storage/databases/main/receipts.py | 87 +++++++++++++++++++------- synapse/types.py | 1 + tests/federation/test_federation_sender.py | 21 ++++++- tests/handlers/test_appservice.py | 1 + tests/replication/slave/storage/test_events.py | 2 +- tests/replication/tcp/streams/test_receipts.py | 15 ++++- tests/storage/test_event_push_actions.py | 1 + tests/storage/test_receipts.py | 36 ++++++++--- 17 files changed, 173 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13782.feature (limited to 'synapse/replication/tcp/client.py') diff --git a/changelog.d/13782.feature b/changelog.d/13782.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13782.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 595eb007a5..933779c23a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -83,6 +83,8 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + # MSC3771: Thread read receipts + self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index afaf3261df..4768a34c07 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,6 +63,8 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() + self._msc3771_enabled = hs.config.experimental.msc3771_enabled + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -91,13 +93,23 @@ class ReceiptsHandler: ) continue + # Check if these receipts apply to a thread. + thread_id = None + data = user_values.get("data", {}) + if self._msc3771_enabled and isinstance(data, dict): + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None + receipts.append( ReadReceipt( room_id=room_id, receipt_type=receipt_type, user_id=user_id, event_ids=user_values["event_ids"], - data=user_values.get("data", {}), + thread_id=thread_id, + data=data, ) ) @@ -114,6 +126,7 @@ class ReceiptsHandler: receipt.receipt_type, receipt.user_id, receipt.event_ids, + receipt.thread_id, receipt.data, ) @@ -146,7 +159,12 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -156,6 +174,7 @@ class ReceiptsHandler: receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], + thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cf9cd6833b..b2522f98ca 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -427,7 +427,8 @@ class FederationSenderHandler: receipt.receipt_type, receipt.user_id, [receipt.event_id], - receipt.data, + thread_id=receipt.thread_id, + data=receipt.data, ) await self.federation_sender.send_read_receipt(receipt_info) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 398bebeaa6..e01155ad59 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -361,6 +361,7 @@ class ReceiptsStream(Stream): receipt_type: str user_id: str event_id: str + thread_id: Optional[str] data: dict NAME = "receipts" diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 5e53096539..852838515c 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -83,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + # Setting the thread ID is not possible with the /read_markers endpoint. + thread_id=None, ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 5b7fad7402..f3ff156abe 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,6 +49,7 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } + self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -61,7 +62,17 @@ class ReceiptRestServlet(RestServlet): f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - parse_json_object_from_request(request, allow_empty_body=False) + body = parse_json_object_from_request(request) + + # Pull the thread ID, if one exists. + thread_id = None + if self._msc3771_enabled: + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, "thread_id field must be a non-empty string" + ) await self.presence_handler.bump_presence_active_time(requester.user) @@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + thread_id=thread_id, ) return 200, {} diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index b3917a5abc..c95b0d6f19 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,6 +103,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above + # Support for thread read receipts. + "org.matrix.msc3771": self.config.experimental.msc3771_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 921cd4dc5e..9d116f6925 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -95,6 +95,8 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", "event_push_summary": "event_push_summary_unique_index", + "receipts_linearized": "receipts_linearized_unique_index", + "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ddb8e80b69..52fe0db924 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool + ]: """Get updates for receipts replication stream. Args: @@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_all_updated_receipts_txn( txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + int, + bool, + ]: sql = """ - SELECT stream_id, room_id, receipt_type, user_id, event_id, data + SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) updates = cast( - List[Tuple[int, list]], - [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn], ) limited = False @@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_id: str, + thread_id: Optional[str], data: JsonDict, stream_id: int, ) -> Optional[int]: @@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore): # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + if thread_id is None: + thread_clause = "r.thread_id IS NULL" + thread_args: Tuple[str, ...] = () + else: + thread_clause = "r.thread_id = ?" + thread_args = (thread_id,) + + sql = f""" + SELECT stream_ordering, event_id FROM events + INNER JOIN receipts_linearized AS r USING (event_id, room_id) + WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause} + """ + txn.execute( + sql, + ( + room_id, + receipt_type, + user_id, + ) + + thread_args, ) - txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: @@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "stream_id": stream_id, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, @@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: dict, ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. @@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, linearized_event_id, + thread_id, data, stream_id=stream_id, # Read committed is actually beneficial here because we check for a receipt with @@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore): now = self._clock.time_msec() logger.debug( - "RR for event %s in %s (%i ms old)", + "Receipt %s for event %s in %s (%i ms old)", + receipt_type, linearized_event_id, room_id, now - event_ts, @@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, event_ids, + thread_id, data, ) @@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: JsonDict, ) -> None: assert self._can_write_to_receipts @@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, diff --git a/synapse/types.py b/synapse/types.py index ec44601f54..773f0438d5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -835,6 +835,7 @@ class ReadReceipt: receipt_type: str user_id: str event_ids: List[str] + thread_id: Optional[str] data: JsonDict diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5aa500ef8..f1e357764f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): # send the second RR receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["other_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b17af2725b..af24c4984d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipt_type="m.read", user_id=self.local_user, event_ids=[f"$eventid_{i}"], + thread_id=None, data={}, ) ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 49a21e2e85..efd92793c0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -171,7 +171,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if send_receipt: self.get_success( self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} ) ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index eb00117845..ede6d0c118 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} + "!room:blue", + "m.read", + USER_ID, + ["$event:blue"], + thread_id=None, + data={"a": 1}, ) ) self.replicate() @@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) + self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) # Now let's disconnect and insert some data. @@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + "!room2:blue", + "m.read", + USER_ID, + ["$event2:foo"], + thread_id=None, + data={"a": 2}, ) ) self.replicate() diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index fc43d7edd1..08c74b93e3 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -106,6 +106,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "m.read", user_id=user_id, event_ids=[event_id], + thread_id=None, data={}, ) ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index c89bfff241..9459ee1705 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -131,13 +131,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -164,7 +169,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -180,7 +185,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( @@ -202,13 +212,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -241,7 +256,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -259,7 +274,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( -- cgit 1.5.1 From 7b7478e8b65cceb9e7362c6c1cb932b569a6f383 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 5 Oct 2022 10:12:48 -0700 Subject: Batch up notifications after event persistence (#14033) --- changelog.d/14033.misc | 1 + synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 25 ++++++------ synapse/notifier.py | 75 ++++++++++++++++++++---------------- synapse/replication/tcp/client.py | 19 ++++----- 5 files changed, 66 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14033.misc (limited to 'synapse/replication/tcp/client.py') diff --git a/changelog.d/14033.misc b/changelog.d/14033.misc new file mode 100644 index 0000000000..fe42852aa5 --- /dev/null +++ b/changelog.d/14033.misc @@ -0,0 +1 @@ +Don't repeatedly wake up the same users for batched events. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 778d8869b3..da319943cc 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2240,8 +2240,8 @@ class FederationEventHandler: event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) - await self._notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users + await self._notifier.on_new_room_events( + [(event, event_pos)], max_stream_token, extra_users=extra_users ) if event.type == EventTypes.Member and event.membership == Membership.JOIN: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 00e7645ba5..da1acea275 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1872,6 +1872,7 @@ class EventCreationHandler: events_and_context, backfilled=backfilled ) + events_and_pos = [] for event in persisted_events: if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. @@ -1880,25 +1881,23 @@ class EventCreationHandler: stream_ordering = event.internal_metadata.stream_ordering assert stream_ordering is not None pos = PersistedEventPosition(self._instance_name, stream_ordering) - - async def _notify() -> None: - try: - await self.notifier.on_new_room_event( - event, pos, max_stream_token, extra_users=extra_users - ) - except Exception: - logger.exception( - "Error notifying about new room event %s", - event.event_id, - ) - - run_in_background(_notify) + events_and_pos.append((event, pos)) if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + async def _notify() -> None: + try: + await self.notifier.on_new_room_events( + events_and_pos, max_stream_token, extra_users=extra_users + ) + except Exception: + logger.exception("Error notifying about new room events") + + run_in_background(_notify) + return persisted_events[-1] async def _maybe_kick_guest_users( diff --git a/synapse/notifier.py b/synapse/notifier.py index c42bb8266a..26b97cf766 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -294,35 +294,31 @@ class Notifier: """ self._new_join_in_room_callbacks.append(cb) - async def on_new_room_event( + async def on_new_room_events( self, - event: EventBase, - event_pos: PersistedEventPosition, + events_and_pos: List[Tuple[EventBase, PersistedEventPosition]], max_room_stream_token: RoomStreamToken, extra_users: Optional[Collection[UserID]] = None, ) -> None: - """Unwraps event and calls `on_new_room_event_args`.""" - await self.on_new_room_event_args( - event_pos=event_pos, - room_id=event.room_id, - event_id=event.event_id, - event_type=event.type, - state_key=event.get("state_key"), - membership=event.content.get("membership"), - max_room_stream_token=max_room_stream_token, - extra_users=extra_users or [], - ) + """Creates a _PendingRoomEventEntry for each of the listed events and calls + notify_new_room_events with the results.""" + event_entries = [] + for event, pos in events_and_pos: + entry = self.create_pending_room_event_entry( + pos, + extra_users, + event.room_id, + event.type, + event.get("state_key"), + event.content.get("membership"), + ) + event_entries.append((entry, event.event_id)) + await self.notify_new_room_events(event_entries, max_room_stream_token) - async def on_new_room_event_args( + async def notify_new_room_events( self, - room_id: str, - event_id: str, - event_type: str, - state_key: Optional[str], - membership: Optional[str], - event_pos: PersistedEventPosition, + event_entries: List[Tuple[_PendingRoomEventEntry, str]], max_room_stream_token: RoomStreamToken, - extra_users: Optional[Collection[UserID]] = None, ) -> None: """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -338,22 +334,33 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append( - _PendingRoomEventEntry( - event_pos=event_pos, - extra_users=extra_users or [], - room_id=room_id, - type=event_type, - state_key=state_key, - membership=membership, - ) - ) - self._notify_pending_new_room_events(max_room_stream_token) + for event_entry, event_id in event_entries: + self.pending_new_room_events.append(event_entry) + await self._third_party_rules.on_new_event(event_id) - await self._third_party_rules.on_new_event(event_id) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() + def create_pending_room_event_entry( + self, + event_pos: PersistedEventPosition, + extra_users: Optional[Collection[UserID]], + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + ) -> _PendingRoomEventEntry: + """Creates and returns a _PendingRoomEventEntry""" + return _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users or [], + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + def _notify_pending_new_room_events( self, max_room_stream_token: RoomStreamToken ) -> None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b2522f98ca..18252a2958 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -210,15 +210,16 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - await self.notifier.on_new_room_event_args( - event_pos=event_pos, - max_room_stream_token=max_token, - extra_users=extra_users, - room_id=row.data.room_id, - event_id=row.data.event_id, - event_type=row.data.type, - state_key=row.data.state_key, - membership=row.data.membership, + event_entry = self.notifier.create_pending_room_event_entry( + event_pos, + extra_users, + row.data.room_id, + row.data.type, + row.data.state_key, + row.data.membership, + ) + await self.notifier.notify_new_room_events( + [(event_entry, row.data.event_id)], max_token ) # If this event is a join, make a note of it so we have an accurate -- cgit 1.5.1