From ab18441573dc14cea1fe4082b2a89b9d392a4b9f Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Fri, 5 Aug 2022 17:09:33 +0200 Subject: Support stable identifiers for MSC2285: private read receipts. (#13273) This adds support for the stable identifiers of MSC2285 while continuing to support the unstable identifiers behind the configuration flag. These will be removed in a future version. --- synapse/rest/client/versions.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/rest/client/versions.py') diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 0366986755..c9a830cbac 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -94,6 +94,7 @@ class VersionsRestServlet(RestServlet): # Supports the busy presence state described in MSC3026. "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 + "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, -- cgit 1.5.1 From 0e99f07952edcb6396654e34da50ddeb0a211067 Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Thu, 1 Sep 2022 14:31:54 +0200 Subject: Remove support for unstable private read receipts (#13653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- changelog.d/13653.removal | 1 + synapse/api/constants.py | 1 - synapse/config/experimental.py | 3 -- synapse/handlers/receipts.py | 29 +++---------- synapse/replication/tcp/client.py | 5 +-- synapse/rest/client/notifications.py | 1 - synapse/rest/client/read_marker.py | 2 - synapse/rest/client/receipts.py | 2 - synapse/rest/client/versions.py | 1 - .../storage/databases/main/event_push_actions.py | 2 - tests/handlers/test_receipts.py | 48 ++++++---------------- tests/rest/client/test_sync.py | 37 +++++------------ tests/storage/test_receipts.py | 34 ++++++--------- 13 files changed, 44 insertions(+), 122 deletions(-) create mode 100644 changelog.d/13653.removal (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/13653.removal b/changelog.d/13653.removal new file mode 100644 index 0000000000..eb075d4517 --- /dev/null +++ b/changelog.d/13653.removal @@ -0,0 +1 @@ +Remove support for unstable [private read receipts](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c73aea622a..c178ddf070 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -258,7 +258,6 @@ class GuestAccess: class ReceiptTypes: READ: Final = "m.read" READ_PRIVATE: Final = "m.read.private" - UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private" FULLY_READ: Final = "m.fully_read" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index c1ff417539..260db49cad 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,9 +32,6 @@ class ExperimentalConfig(Config): # MSC2716 (importing historical messages) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) - # MSC2285 (unstable private read receipts) - self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) - # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index d4a866b346..d2bdb9c8be 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -163,10 +163,7 @@ class ReceiptsHandler: if not is_new: return - if self.federation_sender and receipt_type not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: await self.federation_sender.send_read_receipt(receipt) @@ -206,38 +203,24 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id, orig_event_content in room.get("content", {}).items(): event_content = orig_event_content # If there are private read receipts, additional logic is necessary. - if ( - ReceiptTypes.READ_PRIVATE in event_content - or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content - ): + if ReceiptTypes.READ_PRIVATE in event_content: # Make a copy without private read receipts to avoid leaking # other user's private read receipts.. event_content = { receipt_type: receipt_value for receipt_type, receipt_value in event_content.items() - if receipt_type - not in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ) + if receipt_type != ReceiptTypes.READ_PRIVATE } # Copy the current user's private read receipt from the # original content, if it exists. - user_private_read_receipt = orig_event_content.get( - ReceiptTypes.READ_PRIVATE, {} - ).get(user_id, None) + user_private_read_receipt = orig_event_content[ + ReceiptTypes.READ_PRIVATE + ].get(user_id, None) if user_private_read_receipt: event_content[ReceiptTypes.READ_PRIVATE] = { user_id: user_private_read_receipt } - user_unstable_private_read_receipt = orig_event_content.get( - ReceiptTypes.UNSTABLE_READ_PRIVATE, {} - ).get(user_id, None) - if user_unstable_private_read_receipt: - event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = { - user_id: user_unstable_private_read_receipt - } # Include the event if there is at least one non-private read # receipt or the current user has a private read receipt. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 1ed7230e32..e4f2201c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -416,10 +416,7 @@ class FederationSenderHandler: if not self._is_mine_id(receipt.user_id): continue # Private read receipts never get sent over federation. - if receipt.receipt_type in ( - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, - ): + if receipt.receipt_type == ReceiptTypes.READ_PRIVATE: continue receipt_info = ReadReceipt( receipt.room_id, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index a73322a6a4..61268e3af1 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -62,7 +62,6 @@ class NotificationsServlet(RestServlet): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index aaad8b233f..5e53096539 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -45,8 +45,6 @@ class ReadMarkerRestServlet(RestServlet): ReceiptTypes.FULLY_READ, ReceiptTypes.READ_PRIVATE, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index c6108fc5eb..5b7fad7402 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,8 +49,6 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } - if hs.config.experimental.msc2285_enabled: - self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE) async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c9a830cbac..c516cda95d 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -95,7 +95,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, # Supports receiving private read receipts as per MSC2285 "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec - "org.matrix.msc2285": self.config.experimental.msc2285_enabled, # Supports filtering of /publicRooms by room type as per MSC3827 "org.matrix.msc3827.stable": True, # Adds support for importing historical messages as per MSC2716 diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 9f410d69de..f4a07de2a3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -274,7 +274,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt_types=( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) @@ -468,7 +467,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ( ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ), ) diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5f70a2db79..b55238650c 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,8 +15,6 @@ from copy import deepcopy from typing import List -from parameterized import parameterized - from synapse.api.constants import EduTypes, ReceiptTypes from synapse.types import JsonDict @@ -27,16 +25,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt(self, receipt_type: str) -> None: + def test_filters_out_private_receipt(self) -> None: self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } @@ -50,18 +45,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_filters_out_private_receipt_and_ignores_rest( - self, receipt_type: str - ) -> None: + def test_filters_out_private_receipt_and_ignores_rest(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -94,18 +84,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -175,18 +162,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest( - self, receipt_type: str + self, ) -> None: self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, }, @@ -262,16 +246,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_leaves_our_private_and_their_public(self, receipt_type: str) -> None: + def test_leaves_our_private_and_their_public(self) -> None: self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -296,7 +277,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@me:server.org": { "ts": 1436451550453, }, @@ -319,16 +300,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_we_do_not_mutate(self, receipt_type: str) -> None: + def test_we_do_not_mutate(self) -> None: """Ensure the input values are not modified.""" events = [ { "content": { "$1435641916114394fHBLK:matrix.org": { - receipt_type: { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, } diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index de0dec8539..0af643ecd9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -391,7 +391,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() - config["experimental_features"] = {"msc2285_enabled": True} return self.setup_test_homeserver(config=config) @@ -413,17 +412,14 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_read_receipts(self, receipt_type: str) -> None: + def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -432,10 +428,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_public_receipt_can_override_private(self, receipt_type: str) -> None: + def test_public_receipt_can_override_private(self) -> None: """ Sending a public read receipt to the same event which has a private read receipt should cause that receipt to become public. @@ -446,7 +439,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -465,10 +458,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None: + def test_private_receipt_cannot_override_public(self) -> None: """ Sending a private read receipt to the same event which has a public read receipt should cause no change. @@ -489,7 +479,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -554,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): config = super().default_config() config["experimental_features"] = { "msc2654_enabled": True, - "msc2285_enabled": True, } return config @@ -601,10 +590,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): tok=self.tok, ) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_unread_counts(self, receipt_type: str) -> None: + def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" # Check that our own messages don't increase the unread count. @@ -638,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok, ) @@ -726,7 +712,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -738,7 +724,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ] ) def test_read_receipts_only_go_down(self, receipt_type: str) -> None: @@ -752,7 +737,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Read last event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) @@ -763,7 +748,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}", {}, access_token=self.tok, ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 191c957fb5..c89bfff241 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from parameterized import parameterized from synapse.api.constants import ReceiptTypes from synapse.types import UserID, create_requester @@ -92,7 +91,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -104,7 +102,6 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) @@ -117,16 +114,12 @@ class ReceiptTestCase(HomeserverTestCase): [ ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE, - ReceiptTypes.UNSTABLE_READ_PRIVATE, ], ) ) self.assertEqual(res, None) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_receipts_for_user(self, receipt_type: str) -> None: + def test_get_receipts_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -144,14 +137,14 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -164,7 +157,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the public receipt res = self.get_success( - self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -187,20 +180,17 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, receipt_type] + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - @parameterized.expand( - [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] - ) - def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: + def test_get_last_receipt_event_id_for_user(self) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -218,7 +208,7 @@ class ReceiptTestCase(HomeserverTestCase): # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) @@ -227,7 +217,7 @@ class ReceiptTestCase(HomeserverTestCase): self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event1_2_id) @@ -243,7 +233,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( self.store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [receipt_type] + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, event1_2_id) @@ -269,14 +259,14 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, receipt_type], + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) ) self.assertEqual(res, event2_1_id) -- cgit 1.5.1 From 0fd2f2d46064efd37284a36d5b478815d69ddd96 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 21 Sep 2022 16:12:29 +0100 Subject: Implementation of MSC3882 login token request (#13722) --- changelog.d/13722.feature | 1 + synapse/config/experimental.py | 7 ++ synapse/rest/__init__.py | 2 + synapse/rest/client/login_token_request.py | 94 ++++++++++++++++++ synapse/rest/client/versions.py | 2 + tests/rest/client/test_login_token_request.py | 132 ++++++++++++++++++++++++++ 6 files changed, 238 insertions(+) create mode 100644 changelog.d/13722.feature create mode 100644 synapse/rest/client/login_token_request.py create mode 100644 tests/rest/client/test_login_token_request.py (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/13722.feature b/changelog.d/13722.feature new file mode 100644 index 0000000000..588d143c0f --- /dev/null +++ b/changelog.d/13722.feature @@ -0,0 +1 @@ +Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f4541a8db0..bf27f6c101 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -96,3 +96,10 @@ class ExperimentalConfig(Config): # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) + + # MSC3882: Allow an existing session to sign in a new session + self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False) + self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True) + self.msc3882_token_timeout = self.parse_duration( + experimental.get("msc3882_token_timeout", "5m") + ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index b712215112..9a2ab99ede 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client import ( keys, knock, login as v1_login, + login_token_request, logout, mutual_rooms, notifications, @@ -130,3 +131,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) + login_token_request.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py new file mode 100644 index 0000000000..ca5c54bf17 --- /dev/null +++ b/synapse/rest/client/login_token_request.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginTokenRequestServlet(RestServlet): + """ + Get a token that can be used with `m.login.token` to log in a second device. + + Request: + + POST /login/token HTTP/1.1 + Content-Type: application/json + + {} + + Response: + + HTTP/1.1 200 OK + { + "login_token": "ABDEFGH", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/login/token$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + self.server_name = hs.config.server.server_name + self.macaroon_gen = hs.get_macaroon_generator() + self.auth_handler = hs.get_auth_handler() + self.token_timeout = hs.config.experimental.msc3882_token_timeout + self.ui_auth = hs.config.experimental.msc3882_ui_auth + + @interactive_auth_handler + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + + if self.ui_auth: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "issue a new access token for your account", + can_skip_ui_auth=False, # Don't allow skipping of UI auth + ) + + login_token = self.macaroon_gen.generate_short_term_login_token( + user_id=requester.user.to_string(), + auth_provider_id="org.matrix.msc3882.login_token_request", + duration_in_ms=self.token_timeout, + ) + + return ( + 200, + { + "login_token": login_token, + "expires_in": self.token_timeout // 1000, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3882_enabled: + LoginTokenRequestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c516cda95d..c3488f4330 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,6 +105,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, + # Adds support for login token requests as per MSC3882 + "org.matrix.msc3882": self.config.experimental.msc3882_enabled, }, }, ) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py new file mode 100644 index 0000000000..d5bb16c98d --- /dev/null +++ b/tests/rest/client/test_login_token_request.py @@ -0,0 +1,132 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, login_token_request +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + + +class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + admin.register_servlets, + login_token_request.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] + self.hs.config.captcha.enable_registration_captcha = False + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = "user123" + self.password = "password" + + def test_disabled(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 400) + + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_require_auth(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 401) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_uia_on(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 401) + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + session = channel.json_body["session"] + + uia = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.password, + "session": session, + }, + } + + channel = self.make_request("POST", "/login/token", uia, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}} + ) + def test_uia_off(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + { + "experimental_features": { + "msc3882_enabled": True, + "msc3882_ui_auth": False, + "msc3882_token_timeout": "15s", + } + } + ) + def test_expires_in(self) -> None: + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 15) -- cgit 1.5.1 From efabf44c7652095a0e3d9d9083fc8359cdde3854 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 17:18:44 +0100 Subject: Add version flag for MSC3881 (#13860) --- changelog.d/13860.feature | 1 + synapse/rest/client/versions.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/13860.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/13860.feature b/changelog.d/13860.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13860.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c3488f4330..b3917a5abc 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -107,6 +107,8 @@ class VersionsRestServlet(RestServlet): "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 "org.matrix.msc3882": self.config.experimental.msc3882_enabled, + # Adds support for remotely enabling/disabling pushers, as per MSC3881 + "org.matrix.msc3881": self.config.experimental.msc3881_enabled, }, }, ) -- cgit 1.5.1 From efd108b45d1706526416bc9a6f89463b5ff4506a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 23 Sep 2022 10:33:28 -0400 Subject: Accept & store thread IDs for receipts (implement MSC3771). (#13782) Updates the `/receipts` endpoint and receipt EDU handler to parse a `thread_id` from the body and insert it in the database. --- changelog.d/13782.feature | 1 + synapse/config/experimental.py | 2 + synapse/handlers/receipts.py | 23 ++++++- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/streams/_base.py | 1 + synapse/rest/client/read_marker.py | 2 + synapse/rest/client/receipts.py | 14 ++++- synapse/rest/client/versions.py | 2 + synapse/storage/database.py | 2 + synapse/storage/databases/main/receipts.py | 87 +++++++++++++++++++------- synapse/types.py | 1 + tests/federation/test_federation_sender.py | 21 ++++++- tests/handlers/test_appservice.py | 1 + tests/replication/slave/storage/test_events.py | 2 +- tests/replication/tcp/streams/test_receipts.py | 15 ++++- tests/storage/test_event_push_actions.py | 1 + tests/storage/test_receipts.py | 36 ++++++++--- 17 files changed, 173 insertions(+), 41 deletions(-) create mode 100644 changelog.d/13782.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/13782.feature b/changelog.d/13782.feature new file mode 100644 index 0000000000..d0cb902dff --- /dev/null +++ b/changelog.d/13782.feature @@ -0,0 +1 @@ +Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 595eb007a5..933779c23a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -83,6 +83,8 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + # MSC3771: Thread read receipts + self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index afaf3261df..4768a34c07 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,6 +63,8 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() + self._msc3771_enabled = hs.config.experimental.msc3771_enabled + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -91,13 +93,23 @@ class ReceiptsHandler: ) continue + # Check if these receipts apply to a thread. + thread_id = None + data = user_values.get("data", {}) + if self._msc3771_enabled and isinstance(data, dict): + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None + receipts.append( ReadReceipt( room_id=room_id, receipt_type=receipt_type, user_id=user_id, event_ids=user_values["event_ids"], - data=user_values.get("data", {}), + thread_id=thread_id, + data=data, ) ) @@ -114,6 +126,7 @@ class ReceiptsHandler: receipt.receipt_type, receipt.user_id, receipt.event_ids, + receipt.thread_id, receipt.data, ) @@ -146,7 +159,12 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -156,6 +174,7 @@ class ReceiptsHandler: receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], + thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cf9cd6833b..b2522f98ca 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -427,7 +427,8 @@ class FederationSenderHandler: receipt.receipt_type, receipt.user_id, [receipt.event_id], - receipt.data, + thread_id=receipt.thread_id, + data=receipt.data, ) await self.federation_sender.send_read_receipt(receipt_info) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 398bebeaa6..e01155ad59 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -361,6 +361,7 @@ class ReceiptsStream(Stream): receipt_type: str user_id: str event_id: str + thread_id: Optional[str] data: dict NAME = "receipts" diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 5e53096539..852838515c 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -83,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + # Setting the thread ID is not possible with the /read_markers endpoint. + thread_id=None, ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 5b7fad7402..f3ff156abe 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -49,6 +49,7 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } + self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -61,7 +62,17 @@ class ReceiptRestServlet(RestServlet): f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - parse_json_object_from_request(request, allow_empty_body=False) + body = parse_json_object_from_request(request) + + # Pull the thread ID, if one exists. + thread_id = None + if self._msc3771_enabled: + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, "thread_id field must be a non-empty string" + ) await self.presence_handler.bump_presence_active_time(requester.user) @@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet): receipt_type, user_id=requester.user.to_string(), event_id=event_id, + thread_id=thread_id, ) return 200, {} diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index b3917a5abc..c95b0d6f19 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,6 +103,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above + # Support for thread read receipts. + "org.matrix.msc3771": self.config.experimental.msc3771_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 921cd4dc5e..9d116f6925 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -95,6 +95,8 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", "event_push_summary": "event_push_summary_unique_index", + "receipts_linearized": "receipts_linearized_unique_index", + "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ddb8e80b69..52fe0db924 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool + ]: """Get updates for receipts replication stream. Args: @@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_all_updated_receipts_txn( txn: LoggingTransaction, - ) -> Tuple[List[Tuple[int, list]], int, bool]: + ) -> Tuple[ + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + int, + bool, + ]: sql = """ - SELECT stream_id, room_id, receipt_type, user_id, event_id, data + SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) updates = cast( - List[Tuple[int, list]], - [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], + List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], + [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn], ) limited = False @@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_id: str, + thread_id: Optional[str], data: JsonDict, stream_id: int, ) -> Optional[int]: @@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore): # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: - sql = ( - "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" - " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + if thread_id is None: + thread_clause = "r.thread_id IS NULL" + thread_args: Tuple[str, ...] = () + else: + thread_clause = "r.thread_id = ?" + thread_args = (thread_id,) + + sql = f""" + SELECT stream_ordering, event_id FROM events + INNER JOIN receipts_linearized AS r USING (event_id, room_id) + WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause} + """ + txn.execute( + sql, + ( + room_id, + receipt_type, + user_id, + ) + + thread_args, ) - txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: @@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "stream_id": stream_id, "event_id": event_id, "event_stream_ordering": stream_ordering, "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, @@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: dict, ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. @@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, linearized_event_id, + thread_id, data, stream_id=stream_id, # Read committed is actually beneficial here because we check for a receipt with @@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore): now = self._clock.time_msec() logger.debug( - "RR for event %s in %s (%i ms old)", + "Receipt %s for event %s in %s (%i ms old)", + receipt_type, linearized_event_id, room_id, now - event_ts, @@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, event_ids, + thread_id, data, ) @@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type: str, user_id: str, event_ids: List[str], + thread_id: Optional[str], data: JsonDict, ) -> None: assert self._can_write_to_receipts @@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) + keyvalues = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + where_clause = "" + if thread_id is None: + where_clause = "thread_id IS NULL" + else: + keyvalues["thread_id"] = thread_id + self.db_pool.simple_upsert_txn( txn, table="receipts_graph", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, + keyvalues=keyvalues, values={ "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), - "thread_id": None, }, + where_clause=where_clause, # receipts_graph has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, diff --git a/synapse/types.py b/synapse/types.py index ec44601f54..773f0438d5 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -835,6 +835,7 @@ class ReadReceipt: receipt_type: str user_id: str event_ids: List[str] + thread_id: Optional[str] data: JsonDict diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index a5aa500ef8..f1e357764f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): sender = self.hs.get_federation_sender() receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["event_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) @@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): # send the second RR receipt = ReadReceipt( - "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + "room_id", + "m.read", + "user_id", + ["other_id"], + thread_id=None, + data={"ts": 1234}, ) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b17af2725b..af24c4984d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): receipt_type="m.read", user_id=self.local_user, event_ids=[f"$eventid_{i}"], + thread_id=None, data={}, ) ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 49a21e2e85..efd92793c0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -171,7 +171,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if send_receipt: self.get_success( self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} ) ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index eb00117845..ede6d0c118 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} + "!room:blue", + "m.read", + USER_ID, + ["$event:blue"], + thread_id=None, + data={"a": 1}, ) ) self.replicate() @@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) + self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) # Now let's disconnect and insert some data. @@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + "!room2:blue", + "m.read", + USER_ID, + ["$event2:foo"], + thread_id=None, + data={"a": 2}, ) ) self.replicate() diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index fc43d7edd1..08c74b93e3 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -106,6 +106,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): "m.read", user_id=user_id, event_ids=[event_id], + thread_id=None, data={}, ) ) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index c89bfff241..9459ee1705 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -131,13 +131,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -164,7 +169,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -180,7 +185,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( @@ -202,13 +212,18 @@ class ReceiptTestCase(HomeserverTestCase): # Send public read receipt for the first event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {} ) ) # Send private read receipt for the second event self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.room_id1, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event1_2_id], + None, + {}, ) ) @@ -241,7 +256,7 @@ class ReceiptTestCase(HomeserverTestCase): # Test receipt updating self.get_success( self.store.insert_receipt( - self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {} ) ) res = self.get_success( @@ -259,7 +274,12 @@ class ReceiptTestCase(HomeserverTestCase): # Test new room is reflected in what the method returns self.get_success( self.store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.room_id2, + ReceiptTypes.READ_PRIVATE, + OUR_USER_ID, + [event2_1_id], + None, + {}, ) ) res = self.get_success( -- cgit 1.5.1 From b4ec4f5e71a87d5bdc840a4220dfd9a34c54c847 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 09:47:04 -0400 Subject: Track notification counts per thread (implement MSC3773). (#13776) When retrieving counts of notifications segment the results based on the thread ID, but choose whether to return them as individual threads or as a single summed field by letting the client opt-in via a sync flag. The summarization code is also updated to be per thread, instead of per room. --- changelog.d/13776.feature | 1 + synapse/api/constants.py | 3 + synapse/api/filtering.py | 10 ++ synapse/config/experimental.py | 2 + synapse/handlers/sync.py | 40 ++++- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_tools.py | 9 +- synapse/rest/client/sync.py | 4 + synapse/rest/client/versions.py | 3 +- synapse/storage/database.py | 2 +- .../storage/databases/main/event_push_actions.py | 188 +++++++++++++-------- synapse/storage/schema/__init__.py | 6 +- .../delta/73/06thread_notifications_backfill.sql | 29 ++++ .../07thread_notifications_not_null.sql.postgres | 19 +++ .../73/07thread_notifications_not_null.sql.sqlite | 101 +++++++++++ tests/replication/slave/storage/test_events.py | 17 +- tests/storage/test_event_push_actions.py | 169 +++++++++++++++++- 17 files changed, 514 insertions(+), 93 deletions(-) create mode 100644 changelog.d/13776.feature create mode 100644 synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres create mode 100644 synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/13776.feature b/changelog.d/13776.feature new file mode 100644 index 0000000000..22bce125ce --- /dev/null +++ b/changelog.d/13776.feature @@ -0,0 +1 @@ +Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index c031903b1a..44c5ffc6a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -31,6 +31,9 @@ MAX_ALIAS_LENGTH = 255 # the maximum length for a user id is 255 characters MAX_USERID_LENGTH = 255 +# Constant value used for the pseudo-thread which is the main timeline. +MAIN_TIMELINE: Final = "main" + class Membership: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f7f46f8d80..c6e44dcf82 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"}, # Include or exclude events with the provided labels. # cf https://github.com/matrix-org/matrix-doc/pull/2326 "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, @@ -240,6 +241,9 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members + def unread_thread_notifications(self) -> bool: + return self._room_timeline_filter.unread_thread_notifications + async def filter_presence( self, events: Iterable[UserPresenceState] ) -> List[UserPresenceState]: @@ -304,6 +308,12 @@ class Filter: self.include_redundant_members = filter_json.get( "include_redundant_members", False ) + if hs.config.experimental.msc3773_enabled: + self.unread_thread_notifications: bool = filter_json.get( + "org.matrix.msc3773.unread_thread_notifications", False + ) + else: + self.unread_thread_notifications = False self.types = filter_json.get("types", None) self.not_types = filter_json.get("not_types", []) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 83695f24d9..6503ce6e34 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -99,6 +99,8 @@ class ExperimentalConfig(Config): self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) + # MSC3773: Thread notifications + self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) # MSC3715: dir param on /relations. self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4abb9b6127..329e89c604 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -40,7 +40,7 @@ from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -128,6 +128,7 @@ class JoinedSyncResult: ephemeral: List[JsonDict] account_data: List[JsonDict] unread_notifications: JsonDict + unread_thread_notifications: JsonDict summary: Optional[JsonDict] unread_count: int @@ -278,6 +279,8 @@ class SyncHandler: self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + self._msc3773_enabled = hs.config.experimental.msc3773_enabled + async def wait_for_sync_for_user( self, requester: Requester, @@ -1288,7 +1291,7 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> NotifCounts: + ) -> RoomNotifCounts: with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( @@ -2353,6 +2356,7 @@ class SyncHandler: ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + unread_thread_notifications={}, summary=summary, unread_count=0, ) @@ -2360,10 +2364,36 @@ class SyncHandler: if room_sync or always_include: notifs = await self.unread_notifs_for_room_id(room_id, sync_config) - unread_notifications["notification_count"] = notifs.notify_count - unread_notifications["highlight_count"] = notifs.highlight_count + # Notifications for the main timeline. + notify_count = notifs.main_timeline.notify_count + highlight_count = notifs.main_timeline.highlight_count + unread_count = notifs.main_timeline.unread_count - room_sync.unread_count = notifs.unread_count + # Check the sync configuration. + if ( + self._msc3773_enabled + and sync_config.filter_collection.unread_thread_notifications() + ): + # And add info for each thread. + room_sync.unread_thread_notifications = { + thread_id: { + "notification_count": thread_notifs.notify_count, + "highlight_count": thread_notifs.highlight_count, + } + for thread_id, thread_notifs in notifs.threads.items() + if thread_id is not None + } + + else: + # Combine the unread counts for all threads and main timeline. + for thread_notifs in notifs.threads.values(): + notify_count += thread_notifs.notify_count + highlight_count += thread_notifs.highlight_count + unread_count += thread_notifs.unread_count + + unread_notifications["notification_count"] = notify_count + unread_notifications["highlight_count"] = highlight_count + room_sync.unread_count = unread_count sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4270438918..61d952742d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -31,7 +31,7 @@ from typing import ( from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -280,7 +280,7 @@ class BulkPushRuleEvaluator: # If the event does not have a relation, then cannot have any mutual # relations or thread ID. relations = {} - thread_id = "main" + thread_id = MAIN_TIMELINE if relation: relations = await self._get_mutual_relations( relation.parent_id, diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 658bf373b7..edeba27a45 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,7 +39,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - await concurrently_execute(get_room_unread_count, joins, 10) for notifs in room_notifs: - if notifs.notify_count == 0: + # Combine the counts from all the threads. + notify_count = notifs.main_timeline.notify_count + sum( + n.notify_count for n in notifs.threads.values() + ) + + if notify_count == 0: continue if group_by_room: @@ -47,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge += 1 else: # increment the badge count by the number of unread messages in the room - badge += notifs.notify_count + badge += notify_count return badge diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index c2989765ce..f1c23d68e5 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -509,6 +509,10 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + if room.unread_thread_notifications: + result[ + "org.matrix.msc3773.unread_thread_notifications" + ] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c95b0d6f19..280d306483 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -103,8 +103,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above - # Support for thread read receipts. + # Support for thread read receipts & notification counts. "org.matrix.msc3771": self.config.experimental.msc3771_enabled, + "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, # Adds support for login token requests as per MSC3882 diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b4469eb964..7bb21f8f81 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", - "event_push_summary": "event_push_summary_unique_index", + "event_push_summary": "event_push_summary_unique_index2", "receipts_linearized": "receipts_linearized_unique_index", "receipts_graph": "receipts_graph_unique_index", } diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index cdc9ee5a37..3210e9cca1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -88,7 +88,7 @@ from typing import ( import attr -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction): @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ - The per-user, per-room count of notifications. Used by sync and push. + The per-user, per-room, per-thread count of notifications. Used by sync and push. """ notify_count: int = 0 @@ -165,6 +165,21 @@ class NotifCounts: highlight_count: int = 0 +@attr.s(slots=True, auto_attribs=True) +class RoomNotifCounts: + """ + The per-user, per-room count of notifications. Used by sync and push. + """ + + main_timeline: NotifCounts + # Map of thread ID to the notification counts. + threads: Dict[str, NotifCounts] + + def __len__(self) -> int: + # To properly account for the amount of space in any caches. + return len(self.threads) + 1 + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: @@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return result - @cached(tree=True, max_entries=5000) + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the notification count, the highlight count and the unread message count for a given user in a given room after their latest read receipt. @@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: The user to retrieve the counts for. Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", @@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> RoomNotifCounts: # Get the stream ordering of the user's latest receipt in the room. result = self.get_last_unthreaded_receipt_for_user_txn( txn, @@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas room_id: str, user_id: str, receipt_stream_ordering: int, - ) -> NotifCounts: + ) -> RoomNotifCounts: """Get the number of unread messages for a user/room that have happened since the given stream ordering. @@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas receipt in the room. If there are no receipts, the stream ordering of the user's join event. - Returns - A NotifCounts object containing the notification count, the highlight count - and the unread message count. + Returns: + A RoomNotifCounts object containing the notification count, the + highlight count and the unread message count for both the main timeline + and threads. """ - counts = NotifCounts() + main_counts = NotifCounts() + thread_counts: Dict[str, NotifCounts] = {} + + def _get_thread(thread_id: str) -> NotifCounts: + if thread_id == MAIN_TIMELINE: + return main_counts + return thread_counts.setdefault(thread_id, NotifCounts()) # First we pull the counts from the summary table. # @@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # receipt). txn.execute( """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary WHERE room_id = ? AND user_id = ? AND ( (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) OR last_receipt_stream_ordering = ? - ) + ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0) """, (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), ) - row = txn.fetchone() - - summary_stream_ordering = 0 - if row: - summary_stream_ordering = row[0] - counts.notify_count += row[1] - counts.unread_count += row[2] + max_summary_stream_ordering = 0 + for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count + + # Summaries will only be used if they have not been invalidated by + # a recent receipt; track the latest stream ordering or a valid summary. + # + # Note that since there's only one read receipt in the room per user, + # valid summaries are contiguous. + max_summary_stream_ordering = max( + summary_stream_ordering, max_summary_stream_ordering + ) # Next we need to count highlights, which aren't summarised sql = """ - SELECT COUNT(*) FROM event_push_actions + SELECT COUNT(*), thread_id FROM event_push_actions WHERE user_id = ? AND room_id = ? AND stream_ordering > ? AND highlight = 1 + GROUP BY thread_id """ txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) - row = txn.fetchone() - if row: - counts.highlight_count += row[0] + for highlight_count, thread_id in txn: + _get_thread(thread_id).highlight_count += highlight_count # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. start_unread_stream_ordering = max( - receipt_stream_ordering, summary_stream_ordering + receipt_stream_ordering, max_summary_stream_ordering ) - notify_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, start_unread_stream_ordering ) - counts.notify_count += notify_count - counts.unread_count += unread_count + for notif_count, unread_count, thread_id in unread_counts: + counts = _get_thread(thread_id) + counts.notify_count += notif_count + counts.unread_count += unread_count - return counts + return RoomNotifCounts(main_counts, thread_counts) def _get_notif_unread_count_for_user_room( self, @@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, - ) -> Tuple[int, int]: + ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas If this is not given, then no maximum is applied. Return: - A tuple of the notif count and unread count in the given range. + A tuple of the notif count and unread count in the given range for + each thread. """ # If there have been no events in the room since the stream ordering, # there can't be any push actions either. if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): - return 0, 0 + return [] clause = "" args = [user_id, room_id, stream_ordering] @@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # If the max stream ordering is less than the min stream ordering, # then obviously there are zero push actions in that range. if max_stream_ordering <= stream_ordering: - return 0, 0 + return [] sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), - COUNT(CASE WHEN unread = 1 THEN 1 END) - FROM event_push_actions ea - WHERE user_id = ? + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions ea + WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? {clause} + GROUP BY thread_id """ txn.execute(sql, args) - row = txn.fetchone() - - if row: - return cast(Tuple[int, int], row) - - return 0, 0 + return cast(List[Tuple[int, int, str]], txn.fetchall()) async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. - notif_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) - # Replace the previous summary with the new counts. - # - # TODO(threads): Upsert per-thread instead of setting them all to main. - self.db_pool.simple_upsert_txn( + # First mark the summary for all threads in the room as cleared. + self.db_pool.simple_update_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id}, - values={ - "notif_count": notif_count, - "unread_count": unread_count, + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={ + "notif_count": 0, + "unread_count": 0, "stream_ordering": old_rotate_stream_ordering, "last_receipt_stream_ordering": stream_ordering, - "thread_id": "main", }, ) + # Then any updated threads get their notification count and unread + # count updated. + self.db_pool.simple_update_many_txn( + txn, + table="event_push_summary", + key_names=("room_id", "user_id", "thread_id"), + key_values=[(room_id, user_id, row[2]) for row in unread_counts], + value_names=("notif_count", "unread_count"), + value_values=[(row[0], row[1]) for row in unread_counts], + ) + # We always update `event_push_summary_last_receipt_stream_id` to # ensure that we don't rescan the same receipts for remote users. @@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Calculate the new counts that should be upserted into event_push_summary sql = """ - SELECT user_id, room_id, + SELECT user_id, room_id, thread_id, coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, thread_id, count(*) as cnt, max(ea.stream_ordering) as stream_ordering FROM event_push_actions AS ea - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? AND ( old.last_receipt_stream_ordering IS NULL OR old.last_receipt_stream_ordering < ea.stream_ordering ) AND %s = 1 - GROUP BY user_id, room_id + GROUP BY user_id, room_id, thread_id ) AS upd - LEFT JOIN event_push_summary AS old USING (user_id, room_id) + LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id) """ # First get the count of unread messages. @@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # object because we might not have the same amount of rows in each of them. To do # this, we use a dict indexed on the user ID and room ID to make it easier to # populate. - summaries: Dict[Tuple[str, str], _EventPushSummary] = {} + summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {} for row in txn: - summaries[(row[0], row[1])] = _EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], + summaries[(row[0], row[1], row[2])] = _EventPushSummary( + unread_count=row[3], + stream_ordering=row[4], notif_count=0, ) @@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) for row in txn: - if (row[0], row[1]) in summaries: - summaries[(row[0], row[1])].notif_count = row[2] + if (row[0], row[1], row[2]) in summaries: + summaries[(row[0], row[1], row[2])].notif_count = row[3] else: # Because the rules on notifying are different than the rules on marking # a message unread, we might end up with messages that notify but aren't # marked unread, so we might not have a summary for this (user, room) # tuple to complete. - summaries[(row[0], row[1])] = _EventPushSummary( + summaries[(row[0], row[1], row[2])] = _EventPushSummary( unread_count=0, - stream_ordering=row[3], - notif_count=row[2], + stream_ordering=row[4], + notif_count=row[3], ) logger.info("Rotating notifications, handling %d rows", len(summaries)) - # TODO(threads): Update on a per-thread basis. self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - key_names=("user_id", "room_id"), - key_values=[(user_id, room_id) for user_id, room_id in summaries], - value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"), + key_names=("user_id", "room_id", "thread_id"), + key_values=[ + (user_id, room_id, thread_id) + for user_id, room_id, thread_id in summaries + ], + value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ ( summary.notif_count, summary.unread_count, summary.stream_ordering, - "main", ) for summary in summaries.values() ], @@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas ) async def _remove_old_push_actions_that_have_rotated(self) -> None: - """Clear out old push actions that have been summarised.""" + """ + Clear out old push actions that have been summarised (and are older than + 1 day ago). + """ # We want to clear out anything that is older than a day that *has* already # been rotated. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 4a5c947699..19dbf2da7f 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -90,9 +90,9 @@ Changes in SCHEMA_VERSION = 73; SCHEMA_COMPAT_VERSION = ( - # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72 - # could break. - 72 + # The threads_id column must exist for event_push_actions, event_push_summary, + # receipts_linearized, and receipts_graph. + 73 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql new file mode 100644 index 0000000000..0ffde9bbeb --- /dev/null +++ b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql @@ -0,0 +1,29 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Forces the background updates from 06thread_notifications.sql to run in the +-- foreground as code will now require those to be "done". + +DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; + +-- Overwrite any null thread_id columns. +UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; +UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL; + +-- Do not run the event_push_summary_unique_index job if it is pending; the +-- thread_id field will be made required. +DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; +DROP INDEX IF EXISTS event_push_summary_unique_index; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres new file mode 100644 index 0000000000..33674f8c62 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The columns can now be made non-nullable. +ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL; +ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite new file mode 100644 index 0000000000..5322ad77a4 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite @@ -0,0 +1,101 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- SQLite doesn't support modifying columns to an existing table, so it must +-- be recreated. + +-- Create the new tables. +CREATE TABLE event_push_actions_staging_new ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL, + unread SMALLINT, + thread_id TEXT NOT NULL, + inserted_ts BIGINT +); + +CREATE TABLE event_push_actions_new ( + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_tag VARCHAR(32), + actions TEXT NOT NULL, + topological_ordering BIGINT, + stream_ordering BIGINT, + notif SMALLINT, + highlight SMALLINT, + unread SMALLINT, + thread_id TEXT NOT NULL, + CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) +); + +CREATE TABLE event_push_summary_new ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + unread_count BIGINT, + last_receipt_stream_ordering BIGINT, + thread_id TEXT NOT NULL +); + +-- Swap the indexes. +DROP INDEX IF EXISTS event_push_actions_staging_id; +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id); + +DROP INDEX IF EXISTS event_push_actions_room_id_user_id; +DROP INDEX IF EXISTS event_push_actions_rm_tokens; +DROP INDEX IF EXISTS event_push_actions_stream_ordering; +DROP INDEX IF EXISTS event_push_actions_u_highlight; +DROP INDEX IF EXISTS event_push_actions_highlights_index; +CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id); +CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering ); +CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id ); +CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering); +CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering); + +-- Copy the data. +INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) + SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts + FROM event_push_actions_staging; + +INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) + SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id + FROM event_push_actions; + +INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) + SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id + FROM event_push_summary; + +-- Drop the old tables. +DROP TABLE event_push_actions_staging; +DROP TABLE event_push_actions; +DROP TABLE event_push_summary; + +-- Rename the tables. +ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; +ALTER TABLE event_push_actions_new RENAME TO event_push_actions; +ALTER TABLE event_push_summary_new RENAME TO event_push_summary; + +-- Re-run background updates from 72/02event_push_actions_index.sql and +-- 72/06thread_notifications.sql. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_summary_unique_index2', '{}') + ON CONFLICT (update_name) DO NOTHING; +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7307, 'event_push_actions_stream_highlight_index', '{}') + ON CONFLICT (update_name) DO NOTHING; diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index efd92793c0..d42e36cdf1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -22,7 +22,10 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import ( + NotifCounts, + RoomNotifCounts, +) from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -178,7 +181,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} + ), ) self.persist( @@ -191,7 +196,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} + ), ) self.persist( @@ -206,7 +213,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), + RoomNotifCounts( + NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} + ), ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 473c965e19..89f986ac34 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple from twisted.test.proto_helpers import MemoryReactor @@ -20,6 +20,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.types import JsonDict from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -133,13 +134,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) ) self.assertEqual( - counts, + counts.main_timeline, NotifCounts( notify_count=noitf_count, unread_count=0, highlight_count=highlight_count, ), ) + self.assertEqual(counts.threads, {}) def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( @@ -186,6 +188,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(0, 0) _create_event() + _assert_counts(1, 0) _rotate() _assert_counts(1, 0) @@ -236,6 +239,168 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0) + def test_count_aggregation_threads(self) -> None: + """ + This is essentially the same test as test_count_aggregation, but adds + events to the main timeline and to a thread. + """ + + user_id, token, _, other_token, room_id = self._create_users_and_room() + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( -- cgit 1.5.1 From d8663f5e6358f8eaeda9a3f923fae720a140ca4d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Oct 2022 10:21:16 -0400 Subject: Advertise supporting version 1.3 of the Matrix spec. (#14032) Now that all features / changes in 1.3 are supported in Synapse. --- changelog.d/14032.feature | 1 + synapse/rest/client/versions.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/14032.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/14032.feature b/changelog.d/14032.feature new file mode 100644 index 0000000000..bb221d3ca6 --- /dev/null +++ b/changelog.d/14032.feature @@ -0,0 +1 @@ +Advertise Matrix 1.3 support on `/_matrix/client/versions`. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 280d306483..18ed313b5c 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -75,6 +75,7 @@ class VersionsRestServlet(RestServlet): "r0.6.1", "v1.1", "v1.2", + "v1.3", ], # as per MSC1497: "unstable_features": { -- cgit 1.5.1 From 66a785733458d0b5801097caff53624e202a91b4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Oct 2022 09:26:40 -0400 Subject: Use stable identifiers for MSC3771 & MSC3773. (#14050) These are both part of Matrix 1.4 which has now been released. For now, support both the unstable and stable identifiers. --- changelog.d/13776.feature | 2 +- changelog.d/13824.feature | 2 +- changelog.d/13877.feature | 2 +- changelog.d/13878.feature | 2 +- changelog.d/14050.feature | 1 + synapse/api/filtering.py | 13 +++++++---- synapse/config/experimental.py | 2 -- synapse/handlers/receipts.py | 11 ++++------ synapse/handlers/sync.py | 7 +----- synapse/rest/client/receipts.py | 48 ++++++++++++++++++++--------------------- synapse/rest/client/sync.py | 9 +++++--- synapse/rest/client/versions.py | 2 +- 12 files changed, 49 insertions(+), 52 deletions(-) create mode 100644 changelog.d/14050.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/13776.feature b/changelog.d/13776.feature index 22bce125ce..5d0ae16e13 100644 --- a/changelog.d/13776.feature +++ b/changelog.d/13776.feature @@ -1 +1 @@ -Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature index d0cb902dff..5d0ae16e13 100644 --- a/changelog.d/13824.feature +++ b/changelog.d/13824.feature @@ -1 +1 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/13877.feature b/changelog.d/13877.feature index d0cb902dff..5d0ae16e13 100644 --- a/changelog.d/13877.feature +++ b/changelog.d/13877.feature @@ -1 +1 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/13878.feature b/changelog.d/13878.feature index d0cb902dff..5d0ae16e13 100644 --- a/changelog.d/13878.feature +++ b/changelog.d/13878.feature @@ -1 +1 @@ -Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)). +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/changelog.d/14050.feature b/changelog.d/14050.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14050.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index c6e44dcf82..cc31cf8cc7 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + "unread_thread_notifications": {"type": "boolean"}, "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"}, # Include or exclude events with the provided labels. # cf https://github.com/matrix-org/matrix-doc/pull/2326 @@ -308,12 +309,16 @@ class Filter: self.include_redundant_members = filter_json.get( "include_redundant_members", False ) - if hs.config.experimental.msc3773_enabled: - self.unread_thread_notifications: bool = filter_json.get( + self.unread_thread_notifications: bool = filter_json.get( + "unread_thread_notifications", False + ) + if ( + not self.unread_thread_notifications + and hs.config.experimental.msc3773_enabled + ): + self.unread_thread_notifications = filter_json.get( "org.matrix.msc3773.unread_thread_notifications", False ) - else: - self.unread_thread_notifications = False self.types = filter_json.get("types", None) self.not_types = filter_json.get("not_types", []) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 6503ce6e34..c35301207a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3771: Thread read receipts - self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False) # MSC3772: A push rule for mutual relations. self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4768a34c07..4a7ec9e426 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -63,8 +63,6 @@ class ReceiptsHandler: self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - self._msc3771_enabled = hs.config.experimental.msc3771_enabled - async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] @@ -96,11 +94,10 @@ class ReceiptsHandler: # Check if these receipts apply to a thread. thread_id = None data = user_values.get("data", {}) - if self._msc3771_enabled and isinstance(data, dict): - thread_id = data.get("thread_id") - # If the thread ID is invalid, consider it missing. - if not isinstance(thread_id, str): - thread_id = None + thread_id = data.get("thread_id") + # If the thread ID is invalid, consider it missing. + if not isinstance(thread_id, str): + thread_id = None receipts.append( ReadReceipt( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0f684857ca..1db5d68021 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -279,8 +279,6 @@ class SyncHandler: self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync - self._msc3773_enabled = hs.config.experimental.msc3773_enabled - async def wait_for_sync_for_user( self, requester: Requester, @@ -2412,10 +2410,7 @@ class SyncHandler: unread_count = notifs.main_timeline.unread_count # Check the sync configuration. - if ( - self._msc3773_enabled - and sync_config.filter_collection.unread_thread_notifications() - ): + if sync_config.filter_collection.unread_thread_notifications(): # And add info for each thread. room_sync.unread_thread_notifications = { thread_id: { diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 287dfdd69e..14dec7ac4e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -50,7 +50,6 @@ class ReceiptRestServlet(RestServlet): ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ, } - self._msc3771_enabled = hs.config.experimental.msc3771_enabled async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str @@ -67,30 +66,29 @@ class ReceiptRestServlet(RestServlet): # Pull the thread ID, if one exists. thread_id = None - if self._msc3771_enabled: - if "thread_id" in body: - thread_id = body.get("thread_id") - if not thread_id or not isinstance(thread_id, str): - raise SynapseError( - 400, - "thread_id field must be a non-empty string", - Codes.INVALID_PARAM, - ) - - if receipt_type == ReceiptTypes.FULLY_READ: - raise SynapseError( - 400, - f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", - Codes.INVALID_PARAM, - ) - - # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): - raise SynapseError( - 400, - f"event_id {event_id} is not related to thread {thread_id}", - Codes.INVALID_PARAM, - ) + if "thread_id" in body: + thread_id = body.get("thread_id") + if not thread_id or not isinstance(thread_id, str): + raise SynapseError( + 400, + "thread_id field must be a non-empty string", + Codes.INVALID_PARAM, + ) + + if receipt_type == ReceiptTypes.FULLY_READ: + raise SynapseError( + 400, + f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.", + Codes.INVALID_PARAM, + ) + + # Ensure the event ID roughly correlates to the thread ID. + if thread_id != await self._main_store.get_thread_id(event_id): + raise SynapseError( + 400, + f"event_id {event_id} is not related to thread {thread_id}", + Codes.INVALID_PARAM, + ) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f1c23d68e5..8a16459105 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -100,6 +100,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() self._msc2654_enabled = hs.config.experimental.msc2654_enabled + self._msc3773_enabled = hs.config.experimental.msc3773_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # This will always be set by the time Twisted calls us. @@ -510,9 +511,11 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications if room.unread_thread_notifications: - result[ - "org.matrix.msc3773.unread_thread_notifications" - ] = room.unread_thread_notifications + result["unread_thread_notifications"] = room.unread_thread_notifications + if self._msc3773_enabled: + result[ + "org.matrix.msc3773.unread_thread_notifications" + ] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 18ed313b5c..d1d2e5f7e3 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,7 +105,7 @@ class VersionsRestServlet(RestServlet): # Adds support for thread relations, per MSC3440. "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Support for thread read receipts & notification counts. - "org.matrix.msc3771": self.config.experimental.msc3771_enabled, + "org.matrix.msc3771": True, "org.matrix.msc3773": self.config.experimental.msc3773_enabled, # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, -- cgit 1.5.1 From 022f25b3090f7f3a494cecb398bfdbbc2488c2bf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 09:21:55 -0400 Subject: Advertise support for Matrix 1.4. (#14184) All features / changes in Matrix 1.4 are now supported in Synapse. --- changelog.d/14032.feature | 2 +- changelog.d/14184.feature | 1 + synapse/rest/client/versions.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/14184.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/14032.feature b/changelog.d/14032.feature index bb221d3ca6..016c704227 100644 --- a/changelog.d/14032.feature +++ b/changelog.d/14032.feature @@ -1 +1 @@ -Advertise Matrix 1.3 support on `/_matrix/client/versions`. +Advertise support for Matrix 1.3 and 1.4 on `/_matrix/client/versions`. diff --git a/changelog.d/14184.feature b/changelog.d/14184.feature new file mode 100644 index 0000000000..016c704227 --- /dev/null +++ b/changelog.d/14184.feature @@ -0,0 +1 @@ +Advertise support for Matrix 1.3 and 1.4 on `/_matrix/client/versions`. diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index d1d2e5f7e3..4e1fd2bbe7 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -76,6 +76,7 @@ class VersionsRestServlet(RestServlet): "v1.1", "v1.2", "v1.3", + "v1.4", ], # as per MSC1497: "unstable_features": { -- cgit 1.5.1 From 4283bd1cf9c3da2157c3642a7c4f105e9fac2636 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Oct 2022 11:32:11 -0400 Subject: Support filtering the /messages API by relation type (MSC3874). (#14148) Gated behind an experimental configuration flag. --- changelog.d/14148.feature | 1 + synapse/api/filtering.py | 27 +++++- synapse/config/experimental.py | 3 + synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/stream.py | 29 ++++++- tests/api/test_filtering.py | 63 +++++++++++++- tests/rest/client/test_relations.py | 1 - tests/rest/client/test_rooms.py | 145 ++----------------------------- tests/storage/test_stream.py | 118 ++++++++++++++++++------- 9 files changed, 212 insertions(+), 177 deletions(-) create mode 100644 changelog.d/14148.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/14148.feature b/changelog.d/14148.feature new file mode 100644 index 0000000000..951d0cac80 --- /dev/null +++ b/changelog.d/14148.feature @@ -0,0 +1 @@ +Experimental support for [MSC3874](https://github.com/matrix-org/matrix-spec-proposals/pull/3874). diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -117,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4e1fd2bbe7..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -114,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 5baffbfe55..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1278,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1294,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index a269c477fb..a82c4eed86 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,6 +35,8 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" + if "content" not in kwargs: + kwargs["content"] = {} return make_event_from_dict(kwargs) @@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertTrue(Filter(self.hs, definition)._check(event)) + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_rel_type(self): + definition = {"org.matrix.msc3874.rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + @unittest.override_config({"experimental_features": {"msc3874_enabled": True}}) + def test_filter_not_rel_type(self): + definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}}, + ) + + self.assertFalse(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}}, + ) + + self.assertTrue(Filter(self.hs, definition)._check(event)) + def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} filter_id = self.get_success( @@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): self.assertEqual(filtered_room_ids, ["!allowed:example.com"]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_filter_relations(self): events = [ # An event without a relation. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f5c1070b2c..ddf315b894 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1677,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ Test that thread replies are still available when the root event is redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3612ebe7b9..71b1637be8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -35,7 +35,6 @@ from synapse.api.constants import ( EventTypes, Membership, PublicRoomsFilterFields, - RelationTypes, RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException @@ -50,6 +49,7 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.storage.test_stream import PaginationTestCase from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -2915,149 +2915,20 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id -class RelationsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - self.helper.join( - room=self.room_id, user=self.second_user_id, tok=self.second_tok - ) - - self.third_user_id = self.register_user("third", "test") - self.third_tok = self.login("third", "test") - self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) - - # An initial event with a relation from second user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 1"}, - tok=self.tok, - ) - self.event_id_1 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.ANNOTATION, - "event_id": self.event_id_1, - "key": "👍", - } - }, - tok=self.second_tok, - ) - - # Another event with a relation from third user. - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "Message 2"}, - tok=self.tok, - ) - self.event_id_2 = res["event_id"] - self.helper.send_event( - room_id=self.room_id, - type="m.reaction", - content={ - "m.relates_to": { - "rel_type": RelationTypes.REFERENCE, - "event_id": self.event_id_2, - } - }, - tok=self.third_tok, - ) - - # An event with no relations. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "No relations"}, - tok=self.tok, - ) - - def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: +class RelationsTestCase(PaginationTestCase): + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" + from_token = self.get_success( + self.from_token.to_string(self.hs.get_datastores().main) + ) channel = self.make_request( "GET", - "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}", access_token=self.tok, ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - return channel.json_body["chunk"] - - def test_filter_relation_senders(self) -> None: - # Messages which second user reacted to. - filter = {"related_by_senders": [self.second_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which third user reacted to. - filter = {"related_by_senders": [self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which either user reacted to. - filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_type(self) -> None: - # Messages which have annotations. - filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) - - # Messages which have references. - filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_2) - - # Messages which have either annotations or references. - filter = { - "related_by_rel_types": [ - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - ] - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] - ) - - def test_filter_relation_senders_and_type(self) -> None: - # Messages which second user reacted to. - filter = { - "related_by_senders": [self.second_user_id], - "related_by_rel_types": [RelationTypes.ANNOTATION], - } - chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0]["event_id"], self.event_id_1) + return [ev["event_id"] for ev in channel.json_body["chunk"]] class ContextTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 78663a53fe..34fa810cf6 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -16,7 +16,6 @@ from typing import List from synapse.api.constants import EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.events import EventBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.types import JsonDict @@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() - config["experimental_features"] = {"msc3440_enabled": True} + config["experimental_features"] = {"msc3874_enabled": True} return config def prepare(self, reactor, clock, homeserver): @@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase): self.third_tok = self.login("third", "test") self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + # Store a token which is after all the room creation events. + self.from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) + ) + # An initial event with a relation from second user. res = self.helper.send_event( room_id=self.room_id, @@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_1 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.second_tok, ) + self.event_id_annotation = res["event_id"] # Another event with a relation from third user. res = self.helper.send_event( @@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.tok, ) self.event_id_2 = res["event_id"] - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type="m.reaction", content={ @@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase): }, tok=self.third_tok, ) + self.event_id_reference = res["event_id"] # An event with no relations. - self.helper.send_event( + res = self.helper.send_event( room_id=self.room_id, type=EventTypes.Message, content={"msgtype": "m.text", "body": "No relations"}, tok=self.tok, ) + self.event_id_none = res["event_id"] - def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + def _filter_messages(self, filter: JsonDict) -> List[str]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination(self.room_id) - ) - events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( room_id=self.room_id, - from_key=from_token.room_key, + from_key=self.from_token.room_key, to_key=None, - direction="b", + direction="f", limit=10, event_filter=Filter(self.hs, filter), ) ) - return events + return [ev.event_id for ev in events] def test_filter_relation_senders(self): # Messages which second user reacted to. filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which third user reacted to. filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which either user reacted to. filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_type(self): # Messages which have annotations. filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) # Messages which have references. filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_2) + self.assertEqual(chunk, [self.event_id_2]) # Messages which have either annotations or references. filter = { @@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase): ] } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 2, chunk) - self.assertCountEqual( - [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] - ) + self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2]) def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. @@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase): "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) def test_duplicate_relation(self): """An event should only be returned once if there are multiple relations to it.""" @@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase): filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) - self.assertEqual(len(chunk), 1, chunk) - self.assertEqual(chunk[0].event_id, self.event_id_1) + self.assertEqual(chunk, [self.event_id_1]) + + def test_filter_rel_types(self) -> None: + # Messages which are annotations. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_annotation]) + + # Messages which are references. + filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_reference]) + + # Messages which are either annotations or references. + filter = { + "org.matrix.msc3874.rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertCountEqual( + chunk, + [self.event_id_annotation, self.event_id_reference], + ) + + def test_filter_not_rel_types(self) -> None: + # Messages which are not annotations. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_2, + self.event_id_reference, + self.event_id_none, + ], + ) + + # Messages which are not references. + filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual( + chunk, + [ + self.event_id_1, + self.event_id_annotation, + self.event_id_2, + self.event_id_none, + ], + ) + + # Messages which are neither annotations or references. + filter = { + "org.matrix.msc3874.not_rel_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) -- cgit 1.5.1 From 4eaf3eb840b8cfa78d970216c74fc128495f08a5 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 18 Oct 2022 16:52:25 +0100 Subject: Implementation of HTTP 307 response for MSC3886 POST endpoint (#14018) Co-authored-by: reivilibre Co-authored-by: Andrew Morgan --- changelog.d/14018.feature | 1 + synapse/config/experimental.py | 7 +- synapse/config/server.py | 4 ++ synapse/handlers/sso.py | 2 +- synapse/http/server.py | 48 ++++++++++--- synapse/http/site.py | 3 + synapse/rest/__init__.py | 2 + synapse/rest/client/rendezvous.py | 74 +++++++++++++++++++ synapse/rest/client/versions.py | 3 + synapse/rest/key/v2/local_key_resource.py | 4 +- synapse/rest/synapse/client/new_user_consent.py | 3 +- synapse/rest/well_known.py | 3 +- tests/logging/test_terse_json.py | 1 + tests/rest/client/test_rendezvous.py | 45 ++++++++++++ tests/server.py | 8 ++- tests/test_server.py | 94 ++++++++++++++++++------- 16 files changed, 257 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14018.feature create mode 100644 synapse/rest/client/rendezvous.py create mode 100644 tests/rest/client/test_rendezvous.py (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/14018.feature b/changelog.d/14018.feature new file mode 100644 index 0000000000..c8454607eb --- /dev/null +++ b/changelog.d/14018.feature @@ -0,0 +1 @@ +Support for redirecting to an implementation of a [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) HTTP rendezvous service. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f9a49451d8..4009add01d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import attr @@ -120,3 +120,8 @@ class ExperimentalConfig(Config): # MSC3874: Filtering /messages with rel_types / not_rel_types. self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) + + # MSC3886: Simple client rendezvous capability + self.msc3886_endpoint: Optional[str] = experimental.get( + "msc3886_endpoint", None + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index f2353ce5fb..ec46ca63ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,6 +207,9 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None + # If true, the listener will return CORS response headers compatible with MSC3886: + # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 + experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -935,6 +938,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), + experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e035677b8a..5943f08e91 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -874,7 +874,7 @@ class SsoHandler: ) async def handle_terms_accepted( - self, request: Request, session_id: str, terms_version: str + self, request: SynapseRequest, session_id: str, terms_version: str ) -> None: """Handle a request to the new-user 'consent' endpoint diff --git a/synapse/http/server.py b/synapse/http/server.py index bcbfac2c9f..b26e34bceb 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -339,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -598,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -763,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -859,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -870,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -942,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/site.py b/synapse/http/site.py index 55a6afce35..3dbd541fed 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -82,6 +82,7 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -622,6 +623,8 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9a2ab99ede..28542cd774 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client import ( receipts, register, relations, + rendezvous, report_event, room, room_batch, @@ -132,3 +133,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py new file mode 100644 index 0000000000..89176b1ffa --- /dev/null +++ b/synapse/rest/client/rendezvous.py @@ -0,0 +1,74 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http.client import TEMPORARY_REDIRECT +from typing import TYPE_CHECKING, Optional + +from synapse.http.server import HttpServer, respond_with_redirect +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RendezvousServlet(RestServlet): + """ + This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) + simple client rendezvous capability that is used by the "Sign in with QR" functionality. + + This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. + + A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. + + Request: + + POST /rendezvous HTTP/1.1 + Content-Type: ... + + ... + + Response: + + HTTP/1.1 307 + Location: + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint + assert ( + redirection_target is not None + ), "Servlet is only registered if there is a redirection target" + self.endpoint = redirection_target.encode("utf-8") + + async def on_POST(self, request: SynapseRequest) -> None: + respond_with_redirect( + request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True + ) + + # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3886_endpoint is not None: + RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4b87ee978a..9b1b72c68a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -116,6 +116,9 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3881": self.config.experimental.msc3881_enabled, # Adds support for filtering /messages by event relation. "org.matrix.msc3874": self.config.experimental.msc3874_enabled, + # Adds support for simple HTTP rendezvous as per MSC3886 + "org.matrix.msc3886": self.config.experimental.msc3886_endpoint + is not None, }, }, ) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 0c9f042c84..095993415c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -20,9 +20,9 @@ from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.http.server import respond_with_json_bytes +from synapse.http.site import SynapseRequest from synapse.types import JsonDict if TYPE_CHECKING: @@ -99,7 +99,7 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> Optional[int]: + def render_GET(self, request: SynapseRequest) -> Optional[int]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 1c1c7b3613..22784157e6 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.types import UserID from synapse.util.templates import build_jinja_env @@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource): html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6f7ac54c65..e2174fdfea 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.stringutils import parse_server_name @@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: Request) -> bytes: + def render_GET(self, request: SynapseRequest) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 96f399b7ab..0b0d8737c1 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.site_tag = "test-site" site.server_version_string = "Server v1" site.reactor = Mock() + site.experimental_cors_msc3886 = False request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. request.content = BytesIO() diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py new file mode 100644 index 0000000000..ad00a476e1 --- /dev/null +++ b/tests/rest/client/test_rendezvous.py @@ -0,0 +1,45 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.client import rendezvous +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + +endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" + + +class RendezvousServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + rendezvous.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + return self.hs + + def test_disabled(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) + def test_redirect(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) + self.assertEqual(channel.code, 307) + self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) diff --git a/tests/server.py b/tests/server.py index c447d5e4c4..8b1d186219 100644 --- a/tests/server.py +++ b/tests/server.py @@ -266,7 +266,12 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource, reactor: IReactorTime): + def __init__( + self, + resource: IResource, + reactor: IReactorTime, + experimental_cors_msc3886: bool = False, + ): """ Args: @@ -274,6 +279,7 @@ class FakeSite: """ self._resource = resource self.reactor = reactor + self.experimental_cors_msc3886 = experimental_cors_msc3886 def getResourceFor(self, request): return self._resource diff --git a/tests/test_server.py b/tests/test_server.py index 7c66448245..2d9a0257d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase): self.resource = OptionsResource() self.resource.putChild(b"res", DummyResource()) - def _make_request(self, method: bytes, path: bytes) -> FakeChannel: + def _make_request( + self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False + ) -> FakeChannel: """Create a request from the method/path and return a channel with the response.""" # Create a site and query for the resource. site = SynapseSite( "test", "site_tag", - parse_listener_def(0, {"type": "http", "port": 0}), + parse_listener_def( + 0, + { + "type": "http", + "port": 0, + "experimental_cors_msc3886": experimental_cors_msc3886, + }, + ), self.resource, "1.0", max_request_body_size=4096, @@ -239,25 +248,58 @@ class OptionsResourceTests(unittest.TestCase): channel = make_request(self.reactor, site, method, path, shorthand=False) return channel + def _check_cors_standard_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [b"X-Requested-With, Content-Type, Authorization, Date"], + "has correct CORS Headers header", + ) + + def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None: + # Ensure the correct CORS headers have been added + # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"), + [b"*"], + "has correct CORS Origin header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"), + [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec + "has correct CORS Methods header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"), + [ + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match" + ], + "has correct CORS Headers header", + ) + self.assertEqual( + channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"), + [b"ETag, Location, X-Max-Bytes"], + "has correct CORS Expose Headers header", + ) + def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", - ) + self._check_cors_standard_headers(channel) def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" @@ -265,19 +307,17 @@ class OptionsResourceTests(unittest.TestCase): self.assertEqual(channel.code, 204) self.assertNotIn("body", channel.result) - # Ensure the correct CORS headers have been added - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Origin"), - "has CORS Origin header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Methods"), - "has CORS Methods header", - ) - self.assertTrue( - channel.headers.hasHeader(b"Access-Control-Allow-Headers"), - "has CORS Headers header", + self._check_cors_standard_headers(channel) + + def test_known_options_request_msc3886(self) -> None: + """An OPTIONS requests to an known URL still returns 204 No Content.""" + channel = self._make_request( + b"OPTIONS", b"/res/", experimental_cors_msc3886=True ) + self.assertEqual(channel.code, 204) + self.assertNotIn("body", channel.result) + + self._check_cors_msc3886_headers(channel) def test_unknown_request(self) -> None: """A non-OPTIONS request to an unknown URL should 404.""" -- cgit 1.5.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'synapse/rest/client/versions.py') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.5.1