summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2022-01-05 14:19:39 +0000
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2022-01-05 14:19:39 +0000
commit717a5c085a593f00b9454e0155e16f0466b77fd3 (patch)
treef92d46b057c88443443409a8fd53e5c749917bd9 /tests/rest/client
parentMerge branch 'rav/no_bundle_aggregations_in_sync' into matrix-org-hotfixes (diff)
parentMention drop of support in changelog (diff)
downloadsynapse-717a5c085a593f00b9454e0155e16f0466b77fd3.tar.xz
Merge branch 'release-v1.50' into matrix-org-hotfixes
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_auth.py134
-rw-r--r--tests/rest/client/test_relations.py125
-rw-r--r--tests/rest/client/test_room_batch.py180
3 files changed, 387 insertions, 52 deletions
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py

index 72bbc87b4a..27cb856b0a 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py
@@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) channel = self.make_request( "POST", @@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, + HTTPStatus.UNAUTHORIZED, {"username": "user", "type": "m.login.password", "password": "bar"}, ) @@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) # Complete the recaptcha step. - self.recaptcha(session, 200) + self.recaptcha(session, HTTPStatus.OK) # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + self.register( + HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}} + ) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - channel = self.register(200, {"auth": {"session": session}}) + channel = self.register(HTTPStatus.OK, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # will be used.) # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} + HTTPStatus.UNAUTHORIZED, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [self.device_id]}) + channel = self.delete_devices( + HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]} + ) # Grab the session session = channel.json_body["session"] @@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. self.delete_devices( - 200, + HTTPStatus.OK, { "devices": ["dev2"], "auth": { @@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase): # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # Grab the session session = channel.json_body["session"] @@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev2", - 403, + HTTPStatus.FORBIDDEN, { "auth": { "type": "m.login.password", @@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase): self.login("test", self.user_pass, "dev3") # Attempt to delete a device. This works since the user just logged in. - self.delete_device(self.user_tok, "dev2", 200) + self.delete_device(self.user_tok, "dev2", HTTPStatus.OK) # Move the clock forward past the validation timeout. self.reactor.advance(6) # Deleting another devices throws the user into UI auth. - channel = self.delete_device(self.user_tok, "dev3", 401) + channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED) # Grab the session session = channel.json_body["session"] @@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, "dev3", - 200, + HTTPStatus.OK, { "auth": { "type": "m.login.password", @@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # due to re-using the previous session. # # Note that *no auth* information is provided, not even a session iD! - self.delete_device(self.user_tok, self.device_id, 200) + self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK) @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) @@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) # check that SSO is offered flows = channel.json_body["flows"] @@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase): ) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # and now the delete request should succeed. self.delete_device( self.user_tok, self.device_id, - 200, + HTTPStatus.OK, body={"auth": {"session": session_id}}, ) @@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase): # now call the device deletion API: we should get the option to auth with SSO # and not password. - channel = self.delete_device(user_tok, device_id, 401) + channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] # we have no particular expectations of ordering here @@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertEqual(login_resp["user_id"], self.user) # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) + channel = self.delete_device( + self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED + ) flows = channel.json_body["flows"] self.assertIn({"stages": ["m.login.sso"]}, flows) @@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # ... and the delete op should now fail with a 403 self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + self.user_tok, + self.device_id, + HTTPStatus.FORBIDDEN, + body={"auth": {"session": session_id}}, ) @@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_without_refresh = self.make_request( "POST", "/_matrix/client/r0/login", body ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertEqual( + login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result + ) self.assertNotIn("refresh_token", login_without_refresh.json_body) login_with_refresh = self.make_request( @@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertEqual( + login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result + ) self.assertIn("refresh_token", login_with_refresh.json_body) self.assertIn("expires_in_ms", login_with_refresh.json_body) @@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): }, ) self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result + register_without_refresh.code, + HTTPStatus.OK, + register_without_refresh.result, ) self.assertNotIn("refresh_token", register_without_refresh.json_body) @@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "refresh_token": True, }, ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertEqual( + register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result + ) self.assertIn("refresh_token", register_with_refresh.json_body) self.assertIn("expires_in_ms", register_with_refresh.json_body) @@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_response = self.make_request( "POST", "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn("access_token", refresh_response.json_body) self.assertIn("refresh_token", refresh_response.json_body) self.assertIn("expires_in_ms", refresh_response.json_body) @@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) self.assertApproximates( login_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertApproximates( refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", {"refresh_token": True, **body}, ) - self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result) self.assertApproximates( login_response1.json_body["expires_in_ms"], 60 * 1000, 100 ) @@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response2.code, 200, login_response2.result) + self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result) nonrefreshable_access_token = login_response2.json_body["access_token"] # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) @@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) refresh_token = login_response.json_body["refresh_token"] # Advance shy of 2 minutes into the future @@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Refresh our session. The refresh token should still be valid right now. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result) self.assertIn( "refresh_token", refresh_response.json_body, @@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This should fail because the refresh token's lifetime has also been # diminished as our session expired. refresh_response = self.use_refresh_token(refresh_token) - self.assertEqual(refresh_response.code, 403, refresh_response.result) + self.assertEqual( + refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result + ) def test_refresh_token_invalidation(self): """Refresh tokens are invalidated after first use of the next token. @@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/login", body, ) - self.assertEqual(login_response.code, 200, login_response.result) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) # This first refresh should work properly first_refresh_response = self.make_request( @@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result + first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result ) # This one as well, since the token in the first one was never used @@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result + second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result ) # This one should not, since the token from the first refresh is not valid anymore @@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result + third_refresh_response.code, + HTTPStatus.UNAUTHORIZED, + third_refresh_response.result, ) # The associated access token should also be invalid @@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "/_matrix/client/r0/account/whoami", access_token=first_refresh_response.json_body["access_token"], ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result + ) # But all other tokens should work (they will expire after some time) for access_token in [ @@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): whoami_response = self.make_request( "GET", "/_matrix/client/r0/account/whoami", access_token=access_token ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) + self.assertEqual( + whoami_response.code, HTTPStatus.OK, whoami_response.result + ) # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( @@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result + fourth_refresh_response.code, + HTTPStatus.FORBIDDEN, + fourth_refresh_response.result, ) # But refreshing from the last valid refresh token still works @@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result + fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 397c12c2a6..c026d526ef 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -16,6 +16,7 @@ import itertools import urllib.parse from typing import Dict, List, Optional, Tuple +from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable +from tests.test_utils.event_injection import inject_event class RelationsTestCase(unittest.HomeserverTestCase): @@ -574,11 +577,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEquals(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - _find_and_assert_event(room_timeline["events"]) + # channel = self.make_request("GET", "/sync", access_token=self.user_token) + # self.assertEquals(200, channel.code, channel.json_body) + # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + # self.assertTrue(room_timeline["limited"]) + # _find_and_assert_event(room_timeline["events"]) # Note that /relations is tested separately in test_aggregation_get_event_for_thread # since it needs different data configured. @@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_ignore_invalid_room(self): + """Test that we ignore invalid relations over federation.""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Disable the validation to pretend this came over federation. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Generate a various relations from a different room. + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.reaction", + sender=self.user_id, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": parent_id, + "key": "A", + } + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": parent_id, + }, + }, + ) + ) + + self.get_success( + inject_event( + self.hs, + room_id=self.room, + type="m.room.message", + sender=self.user_id, + content={ + "body": "foo", + "msgtype": "m.text", + "new_content": { + "body": "new content", + "msgtype": "m.text", + }, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": parent_id, + }, + }, + ) + ) + + # They should be ignored when fetching relations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And when fetching aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(channel.json_body["chunk"], []) + + # And for bundled aggregations. + channel = self.make_request( + "GET", + f"/rooms/{room2}/event/{parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_edit(self): """Test that a simple edit works.""" diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py new file mode 100644
index 0000000000..721454c187 --- /dev/null +++ b/tests/rest/client/test_room_batch.py
@@ -0,0 +1,180 @@ +import logging +from typing import List, Tuple +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.rest import admin +from synapse.rest.client import login, register, room, room_batch +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +def _create_join_state_events_for_batch_send_request( + virtual_user_ids: List[str], + insert_time: int, +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Member, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "membership": "join", + "displayname": "display-name-for-%s" % (virtual_user_id,), + }, + "state_key": virtual_user_id, + } + for virtual_user_id in virtual_user_ids + ] + + +def _create_message_events_for_batch_send_request( + virtual_user_id: str, insert_time: int, count: int +) -> List[JsonDict]: + return [ + { + "type": EventTypes.Message, + "sender": virtual_user_id, + "origin_server_ts": insert_time, + "content": { + "msgtype": "m.text", + "body": "Historical %d" % (i), + EventContentFields.MSC2716_HISTORICAL: True, + }, + } + for i in range(count) + ] + + +class RoomBatchTestCase(unittest.HomeserverTestCase): + """Test importing batches of historical messages.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room_batch.register_servlets, + room.register_servlets, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + self.appservice = ApplicationService( + token="i_am_an_app_service", + hostname="test", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.clock = clock + self.storage = hs.get_storage() + + self.virtual_user_id = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) + + def _create_test_room(self) -> Tuple[str, str, str, str]: + room_id = self.helper.create_room_as( + self.appservice.sender, tok=self.appservice.token + ) + + res_a = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "A", + }, + tok=self.appservice.token, + ) + event_id_a = res_a["event_id"] + + res_b = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "B", + }, + tok=self.appservice.token, + ) + event_id_b = res_b["event_id"] + + res_c = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "C", + }, + tok=self.appservice.token, + ) + event_id_c = res_c["event_id"] + + return room_id, event_id_a, event_id_b, event_id_c + + @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) + def test_same_state_groups_for_whole_historical_batch(self): + """Make sure that when using the `/batch_send` endpoint to import a + bunch of historical messages, it re-uses the same `state_group` across + the whole batch. This is an easy optimization to make sure we're getting + right because the state for the whole batch is contained in + `state_events_at_start` and can be shared across everything. + """ + + time_before_room = int(self.clock.time_msec()) + room_id, event_id_a, _, _ = self._create_test_room() + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" + % (room_id, event_id_a), + content={ + "events": _create_message_events_for_batch_send_request( + self.virtual_user_id, time_before_room, 3 + ), + "state_events_at_start": _create_join_state_events_for_batch_send_request( + [self.virtual_user_id], time_before_room + ), + }, + access_token=self.appservice.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + # Get the historical event IDs that we just imported + historical_event_ids = channel.json_body["event_ids"] + self.assertEqual(len(historical_event_ids), 3) + + # Fetch the state_groups + state_group_map = self.get_success( + self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + ) + + # We expect all of the historical events to be using the same state_group + # so there should only be a single state_group here! + self.assertEqual( + len(state_group_map.keys()), + 1, + "Expected a single state_group to be returned by saw state_groups=%s" + % (state_group_map.keys(),), + )