diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/federation/test_federation_client.py | 149 | ||||
-rw-r--r-- | tests/federation/transport/test_client.py | 32 | ||||
-rw-r--r-- | tests/handlers/test_password_providers.py | 123 | ||||
-rw-r--r-- | tests/push/test_http.py | 129 | ||||
-rw-r--r-- | tests/rest/client/test_account.py | 9 | ||||
-rw-r--r-- | tests/rest/client/test_auth.py | 93 | ||||
-rw-r--r-- | tests/rest/client/test_device_lists.py | 155 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 42 | ||||
-rw-r--r-- | tests/rest/client/utils.py | 6 | ||||
-rw-r--r-- | tests/storage/databases/test_state_store.py | 352 | ||||
-rw-r--r-- | tests/storage/test_events.py | 107 | ||||
-rw-r--r-- | tests/storage/test_state.py | 109 | ||||
-rw-r--r-- | tests/unittest.py | 21 |
13 files changed, 1226 insertions, 101 deletions
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py new file mode 100644 index 0000000000..ec8864dafe --- /dev/null +++ b/tests/federation/test_federation_client.py @@ -0,0 +1,149 @@ +# Copyright 2022 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock + +import twisted.web.client +from twisted.internet import defer +from twisted.internet.protocol import Protocol +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import FederatingHomeserverTestCase + + +class FederationClientTest(FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + super().prepare(reactor, clock, homeserver) + + # mock out the Agent used by the federation client, which is easier than + # catching the HTTPS connection and do the TLS stuff. + self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) + homeserver.get_federation_http_client().agent = self._mock_agent + + def test_get_room_state(self): + creator = f"@creator:{self.OTHER_SERVER_NAME}" + test_room_id = "!room_id" + + # mock up some events to use in the response. + # In real life, these would have things in `prev_events` and `auth_events`, but that's + # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. + create_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.create", + "state_key": "", + "sender": creator, + "content": {"creator": creator}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 500, + } + ) + member_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 600, + } + ) + pl_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.power_levels", + "sender": creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.return_value = defer.succeed( + _mock_response( + { + "pdus": [ + create_event_dict, + member_event_dict, + pl_event_dict, + ], + "auth_chain": [ + create_event_dict, + member_event_dict, + ], + } + ) + ) + + # now fire off the request + state_resp, auth_resp = self.get_success( + self.hs.get_federation_client().get_room_state( + "yet_another_server", + test_room_id, + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + # ... and that the response is correct. + + # the auth_resp should be empty because all the events are also in state + self.assertEqual(auth_resp, []) + + # all of the events should be returned in state_resp, though not necessarily + # in the same order. We just check the type on the assumption that if the type + # is right, so is the rest of the event. + self.assertCountEqual( + [e.type for e in state_resp], + ["m.room.create", "m.room.member", "m.room.power_levels"], + ) + + +def _mock_response(resp: JsonDict): + body = json.dumps(resp).encode("utf-8") + + def deliver_body(p: Protocol): + p.dataReceived(body) + p.connectionLost(Failure(twisted.web.client.ResponseDone())) + + response = mock.Mock( + code=200, + phrase=b"OK", + headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), + length=len(body), + deliverBody=deliver_body, + ) + mock.seal(response) + return response diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index a7031a55f2..c2320ce133 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase): self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) + self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertEqual(parsed_response.servers_in_room, None, parsed_response) + + def test_partial_state(self) -> None: + """Check that the partial_state flag is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = { + "org.matrix.msc3706.partial_state": True, + } + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertTrue(parsed_response.partial_state) + + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4740dd0a65..49d832de81 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,7 +84,7 @@ class CustomAuthProvider: def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) def check_auth(self, *args): @@ -122,7 +122,7 @@ class PasswordCustomAuthProvider: auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth, - }, + } ) pass @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() - - # Initiate the UIA flow. - username = "rin" - channel = self.make_request( - "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) - # Check that the callback hasn't been called yet. - m.assert_not_called() + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, - ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] + mxid = res["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") # Check that the callback has been called. @@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) + def test_displayname(self): + """Tests that the get_displayname_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_username_for_registration(self) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. """ - async def get_username_for_registration(uia_results, params): + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c068d329a9..e1e3fb97c5 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # # This push should still only contain an unread count of 1 (for 1 unread room) - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 - ) + self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) def test_push_unread_count_message_count(self): @@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # - # We're counting every unread message, so there should now be 4 since the + # We're counting every unread message, so there should now be 3 since the # last read receipt - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 - ) + self._check_push_attempt(6, 3) def _test_push_unread_count(self): """ @@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase): Note that: * Sending messages will cause push notifications to go out to relevant users - * Sending a read receipt will cause a "badge update" notification to go out to - the user that sent the receipt + * Sending a read receipt will cause the HTTP pusher to check whether the unread + count has changed since the last push notification. If so, a "badge update" + notification goes out to the user that sent the receipt """ # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase): # position in the room. We'll set the read position to this event in a moment first_message_event_id = response["event_id"] - # Advance time a bit (so the pusher will register something has happened) and - # make the push succeed - self.push_attempts[0][0].callback({}) + expected_push_attempts = 1 + self._check_push_attempt(expected_push_attempts, 0) + + self._send_read_request(access_token, first_message_event_id, room_id) + + # Unread count has not changed. Therefore, ensure that read request does not + # trigger a push notification. + self.assertEqual(len(self.push_attempts), 1) + + # Send another message + response2 = self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + second_message_event_id = response2["event_id"] + + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 1) + + self._send_read_request(access_token, second_message_event_id, room_id) + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 0) + + # If we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room. Otherwise, it + # should. Therefore, the last call to _check_push_attempt is done in the + # caller method. + self.helper.send(room_id, body="Hello?", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + def _advance_time_and_make_push_succeed(self, expected_push_attempts): self.pump() + self.push_attempts[expected_push_attempts - 1][0].callback({}) + def _check_push_attempt( + self, expected_push_attempts: int, expected_unread_count_last_push: int + ) -> None: + """ + Makes sure that the last expected push attempt succeeds and checks whether + it contains the expected unread count. + """ + self._advance_time_and_make_push_succeed(expected_push_attempts) # Check our push made it - self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(len(self.push_attempts), expected_push_attempts) + _, push_url, push_body = self.push_attempts[expected_push_attempts - 1] self.assertEqual( - self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + push_url, + "http://example.com/_matrix/push/v1/notify", ) - # Check that the unread count for the room is 0 # # The unread count is zero as the user has no read receipt in the room yet self.assertEqual( - self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + push_body["notification"]["counts"]["unread"], + expected_unread_count_last_push, ) + def _send_read_request(self, access_token, message_event_id, room_id): # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase): # count goes down channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id), {}, access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) - - # Advance time and make the push succeed - self.push_attempts[1][0].callback({}) - self.pump() - - # Unread count is still zero as we've read the only message in the room - self.assertEqual(len(self.push_attempts), 2) - self.assertEqual( - self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 - ) - - # Send another message - self.helper.send( - room_id, body="How's the weather today?", tok=other_access_token - ) - - # Advance time and make the push succeed - self.push_attempts[2][0].callback({}) - self.pump() - - # This push should contain an unread count of 1 as there's now been one - # message since our last read receipt - self.assertEqual(len(self.push_attempts), 3) - self.assertEqual( - self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 - ) - - # Since we're grouping by room, sending more messages shouldn't increase the - # unread count, as they're all being sent in the same room - self.helper.send(room_id, body="Hello?", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[3][0].callback({}) - - self.helper.send(room_id, body="Hello??", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[4][0].callback({}) - - self.helper.send(room_id, body="HELLO???", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[5][0].callback({}) - - self.assertEqual(len(self.push_attempts), 6) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 89d85b0a17..51146c471d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -486,8 +486,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) @@ -505,8 +506,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": True, + "is_guest": True, }, ) @@ -528,8 +530,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): whoami, { "user_id": user_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) self.assertFalse(hasattr(whoami, "device_id")) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 27cb856b0a..4a68d66573 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from typing import Optional, Union +from typing import Optional, Tuple, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict, UserID from tests import unittest @@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): auth.register_servlets, account.register_servlets, login.register_servlets, + logout.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] @@ -984,3 +986,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase): self.assertEqual( fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result ) + + def test_many_token_refresh(self): + """ + If a refresh is performed many times during a session, there shouldn't be + extra 'cruft' built up over time. + + This test was written specifically to troubleshoot a case where logout + was very slow if a lot of refreshes had been performed for the session. + """ + + def _refresh(refresh_token: str) -> Tuple[str, str]: + """ + Performs one refresh, returning the next refresh token and access token. + """ + refresh_response = self.use_refresh_token(refresh_token) + self.assertEqual( + refresh_response.code, HTTPStatus.OK, refresh_response.result + ) + return ( + refresh_response.json_body["refresh_token"], + refresh_response.json_body["access_token"], + ) + + def _table_length(table_name: str) -> int: + """ + Helper to get the size of a table, in rows. + For testing only; trivially vulnerable to SQL injection. + """ + + def _txn(txn: LoggingTransaction) -> int: + txn.execute(f"SELECT COUNT(1) FROM {table_name}") + row = txn.fetchone() + # Query is infallible + assert row is not None + return row[0] + + return self.get_success( + self.hs.get_datastores().main.db_pool.runInteraction( + "_table_length", _txn + ) + ) + + # Before we log in, there are no access tokens. + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) + + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + "refresh_token": True, + } + login_response = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) + + access_token = login_response.json_body["access_token"] + refresh_token = login_response.json_body["refresh_token"] + + # Now that we have logged in, there should be one access token and one + # refresh token + self.assertEqual(_table_length("access_tokens"), 1) + self.assertEqual(_table_length("refresh_tokens"), 1) + + for _ in range(5): + refresh_token, access_token = _refresh(refresh_token) + + # After 5 sequential refreshes, there should only be the latest two + # refresh/access token pairs. + # (The last one is preserved because it's in use! + # The one before that is preserved because it can still be used to + # replace the last token pair, in case of e.g. a network interruption.) + self.assertEqual(_table_length("access_tokens"), 2) + self.assertEqual(_table_length("refresh_tokens"), 2) + + logout_response = self.make_request( + "POST", "/_matrix/client/v3/logout", {}, access_token=access_token + ) + self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result) + + # Now that we have logged in, there should be no access token + # and no refresh token + self.assertEqual(_table_length("access_tokens"), 0) + self.assertEqual(_table_length("refresh_tokens"), 0) diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py new file mode 100644 index 0000000000..16070cf027 --- /dev/null +++ b/tests/rest/client/test_device_lists.py @@ -0,0 +1,155 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, login, register + +from tests import unittest + + +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self): + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self): + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index de80aca037..dfd9ffcb93 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_edit_thread(self): + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + threaded_event_id = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEquals( + latest_event_in_thread["content"]["body"], "I've been edited!" + ) + def test_edit_edit(self): """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1c0cb0cf4f..2b3fdadffa 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -106,9 +106,13 @@ class RestHelper: default room version. tok: The access token to use in the request. expect_code: The expected HTTP response code. + extra_content: Extra keys to include in the body of the /createRoom request. + Note that if is_public is set, the "visibility" key will be overridden. + If room_version is set, the "room_version" key will be overridden. + custom_headers: HTTP headers to include in the request. Returns: - The ID of the newly created room. + The ID of the newly created room, or None if the request failed. """ temp_id = self.auth_user_id self.auth_user_id = room_creator diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py new file mode 100644 index 0000000000..076b660809 --- /dev/null +++ b/tests/storage/databases/test_state_store.py @@ -0,0 +1,352 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from typing import Dict, List, Sequence, Tuple +from unittest.mock import patch + +from twisted.internet.defer import Deferred, ensureDeferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes +from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP +from synapse.storage.state import StateFilter +from synapse.types import StateMap +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + +if typing.TYPE_CHECKING: + from synapse.server import HomeServer + +# StateFilter for ALL non-m.room.member state events +ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( + types={EventTypes.Member: set()}, + include_others=True, +) + +FAKE_STATE = { + (EventTypes.Member, "@alice:test"): "join", + (EventTypes.Member, "@bob:test"): "leave", + (EventTypes.Member, "@charlie:test"): "invite", + ("test.type", "a"): "AAA", + ("test.type", "b"): "BBB", + ("other.event.type", "state.key"): "123", +} + + +class StateGroupInflightCachingTestCase(HomeserverTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer" + ) -> None: + self.state_storage = homeserver.get_storage().state + self.state_datastore = homeserver.get_datastores().state + # Patch out the `_get_state_groups_from_groups`. + # This is useful because it lets us pretend we have a slow database. + get_state_groups_patch = patch.object( + self.state_datastore, + "_get_state_groups_from_groups", + self._fake_get_state_groups_from_groups, + ) + get_state_groups_patch.start() + + self.addCleanup(get_state_groups_patch.stop) + self.get_state_group_calls: List[ + Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]] + ] = [] + + def _fake_get_state_groups_from_groups( + self, groups: Sequence[int], state_filter: StateFilter + ) -> "Deferred[Dict[int, StateMap[str]]]": + d: Deferred[Dict[int, StateMap[str]]] = Deferred() + self.get_state_group_calls.append((tuple(groups), state_filter, d)) + return d + + def _complete_request_fake( + self, + groups: Tuple[int, ...], + state_filter: StateFilter, + d: "Deferred[Dict[int, StateMap[str]]]", + ) -> None: + """ + Assemble a fake database response and complete the database request. + """ + + # Return a filtered copy of the fake state + d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) + + def test_duplicate_requests_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests it again, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the same retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # No more calls should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + self.assertEqual(sf, StateFilter.all()) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_smaller_request_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests a subset of that state, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the correct retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", "b"),)) + ) + ) + self.pump(by=0.1) + + # No more calls should have gone to the database, because the second + # request was already in the in-flight cache! + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) + + def test_partially_overlapping_request_deduplicated(self) -> None: + """ + Tests that partially-overlapping requests are partially deduplicated. + + This test: + - requests a single type of wildcard state + (This is internally expanded to be all non-member state) + - requests the entire state in parallel + - checks to see that two database queries were made, but that the second + one is only for member state. + - completes the database queries + - checks that both requests have the correct result. + """ + + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # Because it only partially overlaps, this also went to the database + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + # First request: + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + self._complete_request_fake(groups, sf, d) + + # Second request: + groups, sf, d = self.get_state_group_calls[1] + self.assertEqual(groups, (42,)) + # The state filter is narrowed to only request membership state, because + # the remainder of the state is already being queried in the first request! + self.assertEqual( + sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) + ) + self._complete_request_fake(groups, sf, d) + + # Check the results are correct + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_in_flight_requests_stop_being_in_flight(self) -> None: + """ + Tests that in-flight request deduplication doesn't somehow 'hold on' + to completed requests: once they're done, they're taken out of the + in-flight cache. + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[0]) + self.assertTrue(req1.called) + + # Send off another request + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # It should have gone to the database again, because the previous request + # isn't in-flight and therefore isn't available for deduplication. + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req2.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[1]) + self.assertTrue(req2.called) + groups, sf, d = self.get_state_group_calls[0] + + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_inflight_requests_capped(self) -> None: + """ + Tests that the number of in-flight requests is capped to 5. + + - requests several pieces of state separately + (5 to hit the limit, 1 to 'shunt out', another that comes after the + group has been 'shunted out') + - checks to see that the torrent of requests is shunted out by + rewriting one of the filters as the 'all' state filter + - requests after that one do not cause any additional queries + """ + # 5 at the time of writing. + CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP + + reqs = [] + + # Request 7 different keys (1 to 7) of the `some.state` type. + for req_id in range(CAP_COUNT + 2): + reqs.append( + ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, + StateFilter.freeze( + {"some.state": {str(req_id + 1)}}, include_others=False + ), + ) + ) + ) + self.pump(by=0.1) + + # There should only be 6 calls to the database, not 7. + self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1) + + # Assert that the first 5 are exact requests for the individual pieces + # wanted + for req_id in range(CAP_COUNT): + groups, sf, d = self.get_state_group_calls[req_id] + self.assertEqual( + sf, + StateFilter.freeze( + {"some.state": {str(req_id + 1)}}, include_others=False + ), + ) + + # The 6th request should be the 'all' state filter + groups, sf, d = self.get_state_group_calls[CAP_COUNT] + self.assertEqual(sf, StateFilter.all()) + + # Complete the queries and check which requests complete as a result + for req_id in range(CAP_COUNT): + # This request should not have been completed yet + self.assertFalse(reqs[req_id].called) + + groups, sf, d = self.get_state_group_calls[req_id] + self._complete_request_fake(groups, sf, d) + + # This should have only completed this one request + self.assertTrue(reqs[req_id].called) + + # Now complete the final query; the last 2 requests should complete + # as a result + self.assertFalse(reqs[CAP_COUNT].called) + self.assertFalse(reqs[CAP_COUNT + 1].called) + groups, sf, d = self.get_state_group_calls[CAP_COUNT] + self._complete_request_fake(groups, sf, d) + self.assertTrue(reqs[CAP_COUNT].called) + self.assertTrue(reqs[CAP_COUNT + 1].called) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index f462a8b1c7..a8639d8f82 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) + + +class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastore() + + def test_remote_user_rooms_cache_invalidated(self): + """Test that if the server leaves a room the `get_rooms_for_user` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_rooms_for_user` to add the remote user to the cache + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), {room_id}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), set()) + + def test_room_remote_user_cache_invalidated(self): + """Test that if the server leaves a room the `get_users_in_room` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_users_in_room` to add the remote user to the cache + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(set(users), {user_id, remote_user}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(users, []) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 70d52b088c..28c767ecfd 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase): StateFilter.none(), StateFilter.all(), ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self): + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) diff --git a/tests/unittest.py b/tests/unittest.py index a71892cb9d..7983c1e8b8 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -51,7 +51,10 @@ from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite @@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase): client_ip=client_ip, ) + def add_hashes_and_signatures( + self, + event_dict: JsonDict, + room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) -> JsonDict: + """Adds hashes and signatures to the given event dict + + Returns: + The modified event dict, for convenience + """ + add_hashes_and_signatures( + room_version, + event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + return event_dict + def _auth_header_for_request( origin: str, |