From 365e9482fe18b293f55f78e5f5d2d1107a1d95e1 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 8 Dec 2021 14:54:47 +0000 Subject: Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`. (#11520) --- tests/rest/client/test_auth.py | 134 ++++++++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 47 deletions(-) (limited to 'tests/rest/client') diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 72bbc87b4a..27cb856b0a 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) channel = self.make_request( "POST", @@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, + HTTPStatus.UNAUTHORIZED, {"username": "user", "type": "m.login.password", "password": "bar"}, ) @@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) # Complete the recaptcha step. - self.recaptcha(session, 200) + self.recaptcha(session, HTTPStatus.OK) # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + self.register( + HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}} + ) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) + channel = self.register(HTTPStatus.OK, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # will be used.) # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} + HTTPStatus.UNAUTHORIZED, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) + channel = self.delete_devices( + HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]} + ) # Grab the session session = channel.json_body["session"] @@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. self.delete_devices( - 200, + HTTPStatus.OK, { "devices": ["dev2"], "auth": { @@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev2", - 403, + HTTPStatus.FORBIDDEN, { "auth": { "type": "m.login.password", @@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase): self.login("test", self.user_pass, "dev3") # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) + self.delete_device(self.user_tok, "dev2", HTTPStatus.OK) # Move the clock forward past the validation timeout. self.reactor.advance(6) # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) + channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED) # Grab the session session = channel.json_body["session"] @@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev3", - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # due to re-using the previous session. # # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) + self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK) @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) @@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # check that SSO is offered flows = channel.json_body["flows"] @@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # and now the delete request should succeed. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, body={"auth": {"session": session_id}}, ) @@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase): # now call the device deletion API: we should get the option to auth with SSO # and not password. - channel = self.delete_device(user_tok, device_id, 401) + channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] # we have no particular expectations of ordering here @@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertIn({"stages": ["m.login.sso"]}, flows) @@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # ... and the delete op should now fail with a 403 self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + self.user_tok, + self.device_id, + HTTPStatus.FORBIDDEN, + body={"auth": {"session": session_id}}, ) @@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertEqual( + login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result + ) self.assertNotIn("refresh_token", login_without_refresh.json_body) login_with_refresh = self.make_request( @@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertEqual( + login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result + ) self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) @@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): }, ) self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result + register_without_refresh.code, + HTTPStatus.OK, + register_without_refresh.result, ) self.assertNotIn("refresh_token", register_without_refresh.json_body) @@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "refresh_token": True, }, ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertEqual( + register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result + ) self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) @@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_response = self.make_request( "POST", "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn("access_token", refresh_response.json_body) self.assertIn("refresh_token", refresh_response.json_body) self.assertIn("expires_in_ms", refresh_response.json_body) @@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) self.assertApproximates( login_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result) self.assertApproximates( login_response1.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response2.code, 200, login_response2.result) + self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result) nonrefreshable_access_token = login_response2.json_body["access_token"] # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) @@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_token = login_response.json_body["refresh_token"] # Advance shy of 2 minutes into the future @@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Refresh our session. The refresh token should still be valid right now. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn( "refresh_token", refresh_response.json_body, @@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This should fail because the refresh token's lifetime has also been # diminished as our session expired. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. @@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) # This first refresh should work properly first_refresh_response = self.make_request( @@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result + first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result ) # This one as well, since the token in the first one was never used @@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result + second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result ) # This one should not, since the token from the first refresh is not valid anymore @@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result + third_refresh_response.code, + HTTPStatus.UNAUTHORIZED, + third_refresh_response.result, ) # The associated access token should also be invalid @@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/account/whoami", access_token=first_refresh_response.json_body["access_token"], ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result + ) # But all other tokens should work (they will expire after some time) for access_token in [ @@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): whoami_response = self.make_request( "GET", "/_matrix/client/r0/account/whoami", access_token=access_token ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.OK, whoami_response.result + ) # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( @@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result + fourth_refresh_response.code, + HTTPStatus.FORBIDDEN, + fourth_refresh_response.result, ) # But refreshing from the last valid refresh token still works @@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result + fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) -- cgit 1.5.1 From 3b8872299aac25a7e3ee5a9e00564105aa6de237 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Dec 2021 13:16:01 -0500 Subject: Do not allow cross-room relations, per MSC2674. (#11516) --- changelog.d/11516.bugfix | 1 + synapse/events/utils.py | 11 ++- synapse/rest/client/relations.py | 7 +- synapse/storage/databases/main/events.py | 8 +- synapse/storage/databases/main/relations.py | 36 ++++++--- tests/rest/client/test_relations.py | 115 ++++++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11516.bugfix (limited to 'tests/rest/client') diff --git a/changelog.d/11516.bugfix b/changelog.d/11516.bugfix new file mode 100644 index 0000000000..22bba93671 --- /dev/null +++ b/changelog.d/11516.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 84ef69df67..3da432c1df 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -454,23 +454,26 @@ class EventClientSerializer: return event_id = event.event_id + room_id = event.room_id # The bundled aggregations to include. aggregations = {} - annotations = await self.store.get_aggregation_groups_for_event(event_id) + annotations = await self.store.get_aggregation_groups_for_event( + event_id, room_id + ) if annotations.chunk: aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( - event_id, RelationTypes.REFERENCE, direction="f" + event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id, room_id) if edit: # If there is an edit replace the content, preserving existing @@ -503,7 +506,7 @@ class EventClientSerializer: ( thread_count, latest_thread_event, - ) = await self.store.get_thread_summary(event_id) + ) = await self.store.get_thread_summary(event_id, room_id) if latest_thread_event: aggregations[RelationTypes.THREAD] = { # Don't bundle aggregations as this could recurse forever. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index fc4e6921c5..ffa37ef06c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, limit=limit, @@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet): pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, + room_id=room_id, event_type=event_type, limit=limit, from_token=from_token, @@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. - await self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + room_id=room_id, relation_type=relation_type, event_type=event_type, aggregation_key=key, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4e528612ea..f1f4ce5e07 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1780,10 +1780,14 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + txn.call_after( + self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) + ) if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + txn.call_after( + self.store.get_thread_summary.invalidate, (parent_id, event.room_id) + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0a43acda07..3368a8b084 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_relations_for_event( self, event_id: str, + room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, aggregation_key: Optional[str] = None, @@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. aggregation_key: Only fetch events with this aggregation key, if given. @@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore): the form `{"event_id": "..."}`. """ - where_clause = ["relates_to_id = ?"] - where_args: List[Union[str, int]] = [event_id] + where_clause = ["relates_to_id = ?", "room_id = ?"] + where_args: List[Union[str, int]] = [event_id, room_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_aggregation_groups_for_event( self, event_id: str, + room_id: str, event_type: Optional[str] = None, limit: int = 5, direction: str = "b", @@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. direction: Whether to fetch the highest count first (`"b"`) or @@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] + where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] + where_args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] if event_type: where_clause.append("type = ?") @@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + async def get_applicable_edit( + self, event_id: str, room_id: str + ) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. @@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: The original event ID + room_id: The original event's room ID Returns: The most recent edit, if any. @@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore): WHERE relates_to_id = ? AND relation_type = ? + AND edit.room_id = ? AND edit.type = 'm.room.message' ORDER by edit.origin_server_ts DESC, edit.event_id DESC LIMIT 1 """ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE)) + txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) row = txn.fetchone() if row: return row[0] @@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore): @cached() async def get_thread_summary( - self, event_id: str + self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: """Get the number of threaded replies, the senders of those replies, and the latest reply (if any) for the given event. Args: - event_id: The original event ID + event_id: Summarize the thread related to this event ID. + room_id: The room the event belongs to. Returns: The number of items in the thread and the most recent response, if any. @@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore): INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1 """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) row = txn.fetchone() if row is None: return 0, None @@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore): sql = """ SELECT COALESCE(COUNT(event_id), 0) FROM event_relations + INNER JOIN events USING (event_id) WHERE relates_to_id = ? + AND room_id = ? AND relation_type = ? """ - txn.execute(sql, (event_id, RelationTypes.THREAD)) + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) count = txn.fetchone()[0] # type: ignore[index] return count, latest_event_id diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6..55f4f0b1d0 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -16,6 +16,7 @@ import itertools import urllib.parse from typing import Dict, List, Optional, Tuple +from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import inject_event class RelationsTestCase(unittest.HomeserverTestCase): @@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_ignore_invalid_room(self): + """Test that we ignore invalid relations over federation.""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Disable the validation to pretend this came over federation. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Generate a various relations from a different room. + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.reaction", + sender=self.user_id, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": parent_id, + "key": "A", + } + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "new_content": { + "body": "new content", + "msgtype": "m.text", + }, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": parent_id, + }, + }, + ) + ) + + # They should be ignored when fetching relations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And when fetching aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And for bundled aggregations. + channel = self.make_request( + "GET", + f"/rooms/{room2}/event/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_edit(self): """Test that a simple edit works.""" -- cgit 1.5.1 From 8391bd6ab59387845bae77130dc0ca437c37fb8e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 10 Dec 2021 20:59:20 -0600 Subject: Test to ensure we share the same `state_group` across the whole historical batch (MSC2716) (#11487) Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 We did some work on making sure the `state_groups` were shared in https://github.com/matrix-org/synapse/pull/10975 --- changelog.d/11487.misc | 1 + tests/rest/client/test_room_batch.py | 180 +++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 changelog.d/11487.misc create mode 100644 tests/rest/client/test_room_batch.py (limited to 'tests/rest/client') diff --git a/changelog.d/11487.misc b/changelog.d/11487.misc new file mode 100644 index 0000000000..376b9078be --- /dev/null +++ b/changelog.d/11487.misc @@ -0,0 +1 @@ +Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py new file mode 100644 index 0000000000..721454c187 --- /dev/null +++ b/tests/rest/client/test_room_batch.py @@ -0,0 +1,180 @@ +import logging +from typing import List, Tuple +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room, room_batch +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +def _create_join_state_events_for_batch_send_request( + virtual_user_ids: List[str], + insert_time: int, +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Member, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "membership": "join", + "displayname": "display-name-for-%s" % (virtual_user_id,), + }, + "state_key": virtual_user_id, + } + for virtual_user_id in virtual_user_ids + ] + + +def _create_message_events_for_batch_send_request( + virtual_user_id: str, insert_time: int, count: int +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Message, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "msgtype": "m.text", + "body": "Historical %d" % (i), + EventContentFields.MSC2716_HISTORICAL: True, + }, + } + for i in range(count) + ] + + +class RoomBatchTestCase(unittest.HomeserverTestCase): + """Test importing batches of historical messages.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room_batch.register_servlets, + room.register_servlets, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.clock = clock + self.storage = hs.get_storage() + + self.virtual_user_id = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + def _create_test_room(self) -> Tuple[str, str, str, str]: + room_id = self.helper.create_room_as( + self.appservice.sender, tok=self.appservice.token + ) + + res_a = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "A", + }, + tok=self.appservice.token, + ) + event_id_a = res_a["event_id"] + + res_b = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "B", + }, + tok=self.appservice.token, + ) + event_id_b = res_b["event_id"] + + res_c = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "C", + }, + tok=self.appservice.token, + ) + event_id_c = res_c["event_id"] + + return room_id, event_id_a, event_id_b, event_id_c + + @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) + def test_same_state_groups_for_whole_historical_batch(self): + """Make sure that when using the `/batch_send` endpoint to import a + bunch of historical messages, it re-uses the same `state_group` across + the whole batch. This is an easy optimization to make sure we're getting + right because the state for the whole batch is contained in + `state_events_at_start` and can be shared across everything. + """ + + time_before_room = int(self.clock.time_msec()) + room_id, event_id_a, _, _ = self._create_test_room() + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" + % (room_id, event_id_a), + content={ + "events": _create_message_events_for_batch_send_request( + self.virtual_user_id, time_before_room, 3 + ), + "state_events_at_start": _create_join_state_events_for_batch_send_request( + [self.virtual_user_id], time_before_room + ), + }, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + # Get the historical event IDs that we just imported + historical_event_ids = channel.json_body["event_ids"] + self.assertEqual(len(historical_event_ids), 3) + + # Fetch the state_groups + state_group_map = self.get_success( + self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + ) + + # We expect all of the historical events to be using the same state_group + # so there should only be a single state_group here! + self.assertEqual( + len(state_group_map.keys()), + 1, + "Expected a single state_group to be returned by saw state_groups=%s" + % (state_group_map.keys(),), + ) -- cgit 1.5.1 From 76aa5537ad4f41cad130862a230c1f4cc4bfcbcf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 20 Dec 2021 16:33:35 +0000 Subject: Disable aggregation bundling on `/sync` responses (#11583) * Disable aggregation bundling on `/sync` responses A partial revert of #11478. This turns out to have had a significant CPU impact on initial-sync handling. For now, let's disable it, until we find a more efficient way of achieving this. * Fix tests. Co-authored-by: Patrick Cloke --- changelog.d/11583.bugfix | 1 + synapse/rest/client/sync.py | 10 +++++++++- tests/rest/client/test_relations.py | 10 +++++----- 3 files changed, 15 insertions(+), 6 deletions(-) create mode 100644 changelog.d/11583.bugfix (limited to 'tests/rest/client') diff --git a/changelog.d/11583.bugfix b/changelog.d/11583.bugfix new file mode 100644 index 0000000000..d2ed113e21 --- /dev/null +++ b/changelog.d/11583.bugfix @@ -0,0 +1 @@ +Fix a performance regression in `/sync` handling, introduced in 1.49.0. diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 88e4f5e063..7f5846d389 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -522,7 +522,15 @@ class SyncRestServlet(RestServlet): time_now=time_now, # Don't bother to bundle aggregations if the timeline is unlimited, # as clients will have all the necessary information. - bundle_aggregations=room.timeline.limited, + # bundle_aggregations=room.timeline.limited, + # + # richvdh 2021-12-15: disable this temporarily as it has too high an + # overhead for initialsyncs. We need to figure out a way that the + # bundling can be done *before* the events are stored in the + # SyncResponseCache so that this part can be synchronous. + # + # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations. + bundle_aggregations=False, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 397c12c2a6..1b58b73136 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -574,11 +574,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) + # channel = self.make_request("GET", "/sync", access_token=self.user_token) + # self.assertEquals(200, channel.code, channel.json_body) + # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + # self.assertTrue(room_timeline["limited"]) + # _find_and_assert_event(room_timeline["events"]) # Note that /relations is tested separately in test_aggregation_get_event_for_thread # since it needs different data configured. -- cgit 1.5.1