diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_filtering.py | 107 | ||||
-rw-r--r-- | tests/handlers/test_appservice.py | 8 | ||||
-rw-r--r-- | tests/handlers/test_password_providers.py | 5 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 9 | ||||
-rw-r--r-- | tests/handlers/test_room_summary.py | 55 | ||||
-rw-r--r-- | tests/handlers/test_sync.py | 5 | ||||
-rw-r--r-- | tests/replication/_base.py | 5 | ||||
-rw-r--r-- | tests/rest/admin/test_admin.py | 48 | ||||
-rw-r--r-- | tests/rest/admin/test_room.py | 754 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 26 | ||||
-rw-r--r-- | tests/rest/client/test_directory.py | 105 | ||||
-rw-r--r-- | tests/rest/client/test_login.py | 5 | ||||
-rw-r--r-- | tests/rest/client/test_rooms.py | 154 | ||||
-rw-r--r-- | tests/rest/client/utils.py | 71 | ||||
-rw-r--r-- | tests/server.py | 18 | ||||
-rw-r--r-- | tests/storage/test_profile.py | 9 | ||||
-rw-r--r-- | tests/storage/test_roommember.py | 48 | ||||
-rw-r--r-- | tests/storage/test_stream.py | 207 | ||||
-rw-r--r-- | tests/test_federation.py | 2 | ||||
-rw-r--r-- | tests/unittest.py | 63 |
20 files changed, 1506 insertions, 198 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index f44c91a373..b7fc33dc94 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + import jsonschema from synapse.api.constants import EventContentFields @@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase): {"presence": {"senders": ["@bar;pik.test.com"]}}, ] for filter in invalid_filters: - with self.assertRaises(SynapseError) as check_filter_error: + with self.assertRaises(SynapseError): self.filtering.check_valid_filter(filter) - self.assertIsInstance(check_filter_error.exception, SynapseError) def test_valid_filters(self): valid_filters = [ @@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase): definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_types_works_with_wildcards(self): definition = {"types": ["m.*", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_types_works_with_unknowns(self): definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} @@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="now.for.something.completely.different", room_id="!foo:bar", ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_types_works_with_literals(self): definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_types_works_with_wildcards(self): definition = {"not_types": ["m.room.message", "org.matrix.*"]} event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_types_works_with_unknowns(self): definition = {"not_types": ["m.*", "org.*"]} event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_not_types_takes_priority_over_types(self): definition = { @@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase): "types": ["m.room.message", "m.room.topic"], } event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_senders_works_with_literals(self): definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_senders_works_with_unknowns(self): definition = {"senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_senders_works_with_literals(self): definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_senders_works_with_unknowns(self): definition = {"not_senders": ["@flibble:wibble"]} event = MockEvent( sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_not_senders_takes_priority_over_senders(self): definition = { @@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase): event = MockEvent( sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_rooms_works_with_literals(self): definition = {"rooms": ["!secretbase:unknown"]} event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_rooms_works_with_unknowns(self): definition = {"rooms": ["!secretbase:unknown"]} @@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", room_id="!anothersecretbase:unknown", ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_rooms_works_with_literals(self): definition = {"not_rooms": ["!anothersecretbase:unknown"]} @@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", room_id="!anothersecretbase:unknown", ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_not_rooms_works_with_unknowns(self): definition = {"not_rooms": ["!secretbase:unknown"]} @@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", room_id="!anothersecretbase:unknown", ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_not_rooms_takes_priority_over_rooms(self): definition = { @@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): event = MockEvent( sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_combined_event(self): definition = { @@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", # yup room_id="!stage:unknown", # yup ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_definition_combined_event_bad_sender(self): definition = { @@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", # yup room_id="!stage:unknown", # yup ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_combined_event_bad_room(self): definition = { @@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="m.room.message", # yup room_id="!piggyshouse:muppets", # nope ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_definition_combined_event_bad_type(self): definition = { @@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): type="muppets.misspiggy.kisses", # nope room_id="!stage:unknown", # yup ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_filter_labels(self): definition = {"org.matrix.labels": ["#fun"]} @@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#fun"]}, ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) event = MockEvent( sender="@foo:bar", @@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#notfun"]}, ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) def test_filter_not_labels(self): definition = {"org.matrix.not_labels": ["#fun"]} @@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#fun"]}, ) - self.assertFalse(Filter(definition).check(event)) + self.assertFalse(Filter(self.hs, definition)._check(event)) event = MockEvent( sender="@foo:bar", @@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): content={EventContentFields.LABELS: ["#notfun"]}, ) - self.assertTrue(Filter(definition).check(event)) + self.assertTrue(Filter(self.hs, definition)._check(event)) def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} @@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_presence(events=events) + results = self.get_success(user_filter.filter_presence(events=events)) self.assertEquals(events, results) def test_filter_presence_no_match(self): @@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_presence(events=events) + results = self.get_success(user_filter.filter_presence(events=events)) self.assertEquals([], results) def test_filter_room_state_match(self): @@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_room_state(events=events) + results = self.get_success(user_filter.filter_room_state(events=events)) self.assertEquals(events, results) def test_filter_room_state_no_match(self): @@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) ) - results = user_filter.filter_room_state(events) + results = self.get_success(user_filter.filter_room_state(events)) self.assertEquals([], results) def test_filter_rooms(self): @@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase): "!not_included:example.com", # Disallowed because not in rooms. ] - filtered_room_ids = list(Filter(definition).filter_rooms(room_ids)) + filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids)) self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_filter_relations(self): + events = [ + # An event without a relation. + MockEvent( + event_id="$no_relation", + sender="@foo:bar", + type="org.matrix.custom.event", + room_id="!foo:bar", + ), + # An event with a relation. + MockEvent( + event_id="$with_relation", + sender="@foo:bar", + type="org.matrix.custom.event", + room_id="!foo:bar", + ), + # Non-EventBase objects get passed through. + {}, + ] + + # For the following tests we patch the datastore method (intead of injecting + # events). This is a bit cheeky, but tests the logic of _check_event_relations. + + # Filter for a particular sender. + definition = { + "io.element.relation_senders": ["@foo:bar"], + } + + async def events_have_relations(*args, **kwargs): + return ["$with_relation"] + + with patch.object( + self.datastore, "events_have_relations", new=events_have_relations + ): + filtered_events = list( + self.get_success( + Filter(self.hs, definition)._check_event_relations(events) + ) + ) + self.assertEquals(filtered_events, events[1:]) + def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 1f6a924452..d6f14e2dba 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): make_awaitable(([event], None)) ) - self.handler.notify_interested_services_ephemeral("receipt_key", 580) + self.handler.notify_interested_services_ephemeral( + "receipt_key", 580, ["@fakerecipient:example.com"] + ) self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( interested_service, [event] ) @@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): make_awaitable(([event], None)) ) - self.handler.notify_interested_services_ephemeral("receipt_key", 579) + self.handler.notify_interested_services_ephemeral( + "receipt_key", 580, ["@fakerecipient:example.com"] + ) self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() def _mkservice(self, is_interested, protocols=None): diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 7dd4a5a367..08e9730d4d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -31,7 +31,10 @@ from tests.unittest import override_config # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] # a mock instance which the dummy auth providers delegate to, so we can see what's going # on diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index db691c4c1c..cd6f2c77ae 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - self.store.count_monthly_users = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.count_monthly_users = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - self.store.get_monthly_active_count = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.get_monthly_active_count = Mock( # type: ignore[assignment] return_value=make_awaitable(self.lots_of_users) ) self.get_failure( @@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.store.get_monthly_active_count = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.store.get_monthly_active_count = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index d3d0bf1ac5..7b95844b55 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -14,6 +14,8 @@ from typing import Any, Iterable, List, Optional, Tuple from unittest import mock +from twisted.internet.defer import ensureDeferred + from synapse.api.constants import ( EventContentFields, EventTypes, @@ -316,6 +318,59 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): AuthError, ) + def test_room_hierarchy_cache(self) -> None: + """In-flight room hierarchy requests are deduplicated.""" + # Run two `get_room_hierarchy` calls up until they block. + deferred1 = ensureDeferred( + self.handler.get_room_hierarchy(self.user, self.space) + ) + deferred2 = ensureDeferred( + self.handler.get_room_hierarchy(self.user, self.space) + ) + + # Complete the two calls. + result1 = self.get_success(deferred1) + result2 = self.get_success(deferred2) + + # Both `get_room_hierarchy` calls should return the same result. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_hierarchy(result1, expected) + self._assert_hierarchy(result2, expected) + self.assertIs(result1, result2) + + # A subsequent `get_room_hierarchy` call should not reuse the result. + result3 = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result3, expected) + self.assertIsNot(result1, result3) + + def test_room_hierarchy_cache_sharing(self) -> None: + """Room hierarchy responses for different users are not shared.""" + user2 = self.register_user("user2", "pass") + + # Make the room within the space invite-only. + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + + # Run two `get_room_hierarchy` calls for different users up until they block. + deferred1 = ensureDeferred( + self.handler.get_room_hierarchy(self.user, self.space) + ) + deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space)) + + # Complete the two calls. + result1 = self.get_success(deferred1) + result2 = self.get_success(deferred2) + + # The `get_room_hierarchy` calls should return different results. + self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())]) + self._assert_hierarchy(result2, [(self.space, [self.room])]) + def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content ) -> str: diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 339c039914..638186f173 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -13,10 +13,11 @@ # limitations under the License. from typing import Optional +from unittest.mock import Mock from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.api.filtering import Filtering from synapse.api.room_versions import RoomVersions from synapse.handlers.sync import SyncConfig from synapse.rest import admin @@ -197,7 +198,7 @@ def generate_sync_config( _request_key += 1 return SyncConfig( user=UserID.from_string(user_id), - filter_collection=DEFAULT_FILTER_COLLECTION, + filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION, is_guest=False, request_key=("request_key", _request_key), device_id=device_id, diff --git a/tests/replication/_base.py b/tests/replication/_base.py index eac4664b41..cb02eddf07 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet.protocol import Protocol from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer -from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler @@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. """ - servlets: List[Callable[[HomeServer, JsonResource], None]] = [] - def setUp(self): super().setUp() diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 192073c520..af849bd471 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) + + +class PurgeHistoryTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" + self.url_status = "/_synapse/admin/v1/purge_history_status/" + + def test_purge_history(self): + """ + Simple test of purge history API. + Test only that is is possible to call, get status 200 and purge_id. + """ + + channel = self.make_request( + "POST", + self.url, + content={"delete_local_events": True, "purge_up_to_ts": 0}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("purge_id", channel.json_body) + purge_id = channel.json_body["purge_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status + purge_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 46116644ce..b48fc12e5f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -14,12 +14,16 @@ import json import urllib.parse +from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes +from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room from tests import unittest @@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): @@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): @@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): @@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=True) self._has_no_members(self.room_id) + @parameterized.expand([(True,), (False,)]) + def test_block_unknown_room(self, purge: bool) -> None: + """ + We can block an unknown room. In this case, the `purge` argument + should be ignored. + """ + room_id = "!unknown:test" + + # The room isn't already in the blocked rooms table + self._is_blocked(room_id, expect=False) + + # Request the room be blocked. + channel = self.make_request( + "DELETE", + f"/_synapse/admin/v1/rooms/{room_id}", + {"block": True, "purge": purge}, + access_token=self.admin_user_tok, + ) + + # The room is now blocked. + self.assertEqual( + HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] + ) + self._is_blocked(room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to @@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -418,18 +447,617 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + url = "events?timeout=0&room_id=" + room_id + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + +class DeleteRoomV2TestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.consent.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v2/rooms/{self.room_id}" + self.url_status_by_room_id = ( + f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status" + ) + self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/" + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + method, + url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_room_does_not_exist(self, method: str, url: str): + """ + Check that unknown rooms/server return error 404. + """ + + channel = self.make_request( + method, + url % "!unknown:test", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ] + ) + def test_room_is_not_valid(self, method: str, url: str): + """ + Check that invalid room names, return an error 400. + """ + + channel = self.make_request( + method, + url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@unknown:test"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@not:exist.bla"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "User must be our own: @not:exist.bla", + channel.json_body["error"], + ) + + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"purge": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_delete_expired_status(self): + """Test that the task status is removed after expiration.""" + + # first task, do not purge, that we can create a second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id1 = channel.json_body["delete_id"] + + # go ahead + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + # second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id2 = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(2, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual("complete", channel.json_body["results"][1]["status"]) + self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"]) + + # get status after more than clearing time for first task + # second task is not cleared + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) + + # get status after more than clearing time for all tasks + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_delete_same_room_twice(self): + """Test that the call for delete a room at second time gives an exception.""" + + body = {"new_room_user_id": self.admin_user} + + # first call to delete room + # and do not wait for finish the task + first_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + await_result=False, + ) + + # second call to delete room + second_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + ) + + self.assertEqual( + HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body + ) + self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) + self.assertEqual( + f"History purge already in progress for {self.room_id}", + second_channel.json_body["error"], + ) + + # get result of first call + first_channel.await_result() + self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertIn("delete_id", first_channel.json_body) + + # check status after finish the task + self._test_result( + first_channel.json_body["delete_id"], + self.other_user, + expect_new_room=True, + ) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": False, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": False}, + access_token=self.admin_user_tok, ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + channel = self.make_request( + "PUT", + url.encode("ascii"), + content={"history_visibility": "world_readable"}, + access_token=self.other_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id: str) -> None: + """Assert there is now no longer anyone in the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id: str, user_id: str) -> None: + """Test that user is member of the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id: str) -> None: + """Test that the following tables have been purged of all rows related to the room.""" + for table in PURGE_TABLES: + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") + + def _assert_peek(self, room_id: str, expect_code: int) -> None: + """Assert that the admin user can (or cannot) peek into the room.""" + + url = f"rooms/{room_id}/initialSync" + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + url = "events?timeout=0&room_id=" + room_id channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + def _test_result( + self, + delete_id: str, + kicked_user: str, + expect_new_room: bool = False, + ) -> None: + """ + Test that the result is the expected. + Uses both APIs (status by room_id and delete_id) + + Args: + delete_id: id of this purge + kicked_user: a user_id which is kicked from the room + expect_new_room: if we expect that a new room was created + """ + + # get information by room_id + channel_room_id = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual( + HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body + ) + self.assertEqual(1, len(channel_room_id.json_body["results"])) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + delete_id, channel_room_id.json_body["results"][0]["delete_id"] ) + # get information by delete_id + channel_delete_id = self.make_request( + "GET", + self.url_status_by_delete_id + delete_id, + access_token=self.admin_user_tok, + ) + self.assertEqual( + HTTPStatus.OK, + channel_delete_id.code, + msg=channel_delete_id.json_body, + ) + + # test values that are the same in both responses + for content in [ + channel_room_id.json_body["results"][0], + channel_delete_id.json_body, + ]: + self.assertEqual("complete", content["status"]) + self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0]) + self.assertIn("failed_to_kick_users", content["shutdown_room"]) + self.assertIn("local_aliases", content["shutdown_room"]) + self.assertNotIn("error", content) + + if expect_new_room: + self.assertIsNotNone(content["shutdown_room"]["new_room_id"]) + else: + self.assertIsNone(content["shutdown_room"]["new_room_id"]) + class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" @@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" @@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.second_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self): @@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self): @@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) def test_room_is_not_valid(self): @@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self): @@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self): @@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self): @@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self): @@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals( - 403, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEquals(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self): @@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -1595,7 +2219,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 25e8d6cf27..c9fe0f06c2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + @parameterized.expand(["POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of an user without authentication. """ - channel = self.make_request("POST", self.url) + channel = self.make_request(method, self.url) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + @parameterized.expand(["POST", "DELETE"]) + def test_requester_is_not_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ other_user_token = self.login("user", "pass") - channel = self.make_request("POST", self.url, access_token=other_user_token) + channel = self.make_request(method, self.url, access_token=other_user_token) self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand(["POST", "DELETE"]) + def test_user_is_not_local(self, method: str): """ Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - channel = self.make_request("POST", url, access_token=self.admin_user_tok) + channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self): @@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): result = self.get_success(self.store.get_user_by_access_token(other_user_token)) self.assertTrue(result.shadow_banned) + # Un-shadow-ban the user. + channel = self.make_request( + "DELETE", self.url, access_token=self.admin_user_tok + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is no longer shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + class RateLimitTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index d2181ea907..aca03afd0e 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import directory, login, room +from synapse.server import HomeServer from synapse.types import RoomAlias +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_membership_for_aliases"] = True @@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + """Create two local users and access tokens for them. + One of them creates a room.""" self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_state_event_not_in_room(self): + def test_state_event_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_endpoint_not_in_room(self): + def test_directory_endpoint_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_directory(403) + self.set_alias_via_directory(HTTPStatus.FORBIDDEN) - def test_state_event_in_room_too_long(self): + def test_state_event_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) + self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256) - def test_directory_in_room_too_long(self): + def test_directory_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) + self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256) @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): + def test_state_event_user_in_v5_room(self) -> None: """Test that a regular user can add alias events before room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(200) + self.set_alias_via_state_event(HTTPStatus.OK) @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): + def test_state_event_v6_room(self) -> None: """Test that a regular user can *not* add alias events from room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_in_room(self): + def test_directory_in_room(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(200) + self.set_alias_via_directory(HTTPStatus.OK) - def test_room_creation_too_long(self): + def test_room_creation_too_long(self) -> None: url = "/_matrix/client/r0/createRoom" # We use deliberately a localpart under the length threshold so @@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_room_creation(self): + def test_room_creation(self) -> None: url = "/_matrix/client/r0/createRoom" # Check with an alias of allowed length. There should already be @@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_alias_via_directory(self) -> None: + # Add an alias for the room. We must be joined to do so. + self.ensure_user_joined_room() + alias = self.set_alias_via_directory(HTTPStatus.OK) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_nonexistant_alias(self) -> None: + # Check that no alias exists + alias = "#potato:test" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) - def set_alias_via_state_event(self, expected_code, alias_length=5): + def set_alias_via_state_event( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> None: url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( self.room_id, self.hs.hostname, @@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_code, channel.result) - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + def set_alias_via_directory( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> str: + alias = self.random_alias(alias_length) + url = "/_matrix/client/r0/directory/room/%s" % alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase): "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) + return alias - def random_alias(self, length): + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() - def ensure_user_left_room(self): + def ensure_user_left_room(self) -> None: self.ensure_membership("leave") - def ensure_user_joined_room(self): + def ensure_user_joined_room(self) -> None: self.ensure_membership("join") - def ensure_membership(self, membership): + def ensure_membership(self, membership: str) -> None: try: if membership == "leave": self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a63f04bd41..0b90e3f803 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fΓΆ&=o"')] # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] class LoginRestServletTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 376853fd65..10a4a4dc5e 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -25,7 +25,12 @@ from urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin @@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["experimental_features"] = {"msc3440_enabled": True} + return config + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + self.helper.join( + room=self.room_id, user=self.second_user_id, tok=self.second_tok + ) + + self.third_user_id = self.register_user("third", "test") + self.third_tok = self.login("third", "test") + self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + + # An initial event with a relation from second user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 1"}, + tok=self.tok, + ) + self.event_id_1 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "π", + } + }, + tok=self.second_tok, + ) + + # Another event with a relation from third user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 2"}, + tok=self.tok, + ) + self.event_id_2 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": self.event_id_2, + } + }, + tok=self.third_tok, + ) + + # An event with no relations. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "No relations"}, + tok=self.tok, + ) + + def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: + """Make a request to /messages with a filter, returns the chunk of events.""" + channel = self.make_request( + "GET", + "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["chunk"] + + def test_filter_relation_senders(self): + # Messages which second user reacted to. + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + # Messages which third user reacted to. + filter = {"io.element.relation_senders": [self.third_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_2) + + # Messages which either user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id, self.third_user_id] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_type(self): + # Messages which have annotations. + filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + # Messages which have references. + filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_2) + + # Messages which have either annotations or references. + filter = { + "io.element.relation_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_senders_and_type(self): + # Messages which second user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id], + "io.element.relation_types": [RelationTypes.ANNOTATION], + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + class ContextTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index ec0979850b..1af5e5cee5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,10 +19,21 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + AnyStr, + Dict, + Iterable, + Mapping, + MutableMapping, + Optional, + Tuple, + overload, +) from unittest.mock import patch import attr +from typing_extensions import Literal from twisted.web.resource import Resource from twisted.web.server import Site @@ -45,6 +56,32 @@ class RestHelper: site = attr.ib(type=Site) auth_user_id = attr.ib() + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: Literal[200] = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> str: + ... + + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: int = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> Optional[str]: + ... + def create_room_as( self, room_creator: Optional[str] = None, @@ -53,10 +90,8 @@ class RestHelper: tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ) -> str: + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ) -> Optional[str]: """ Create a room. @@ -99,6 +134,8 @@ class RestHelper: if expect_code == 200: return channel.json_body["room_id"] + else: + return None def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( @@ -168,7 +205,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, - expect_errcode: str = None, + expect_errcode: Optional[str] = None, ) -> None: """ Send a membership state event into a room. @@ -227,9 +264,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if body is None: body = "body_text_here" @@ -254,9 +289,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -418,7 +451,7 @@ class RestHelper: path, content=image_data, access_token=tok, - custom_headers=[(b"Content-Length", str(image_length))], + custom_headers=[("Content-Length", str(image_length))], ) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( @@ -503,7 +536,7 @@ class RestHelper: went. """ - cookies = {} + cookies: Dict[str, str] = {} # if we're doing a ui auth, hit the ui auth redirect endpoint if ui_auth_session_id: @@ -625,7 +658,13 @@ class RestHelper: # hit the redirect url again with the right Host header, which should now issue # a cookie and redirect to the SSO provider. - location = channel.headers.getRawHeaders("Location")[0] + def get_location(channel: FakeChannel) -> str: + location_values = channel.headers.getRawHeaders("Location") + # Keep mypy happy by asserting that location_values is nonempty + assert location_values + return location_values[0] + + location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( self.hs.get_reactor(), @@ -639,7 +678,7 @@ class RestHelper: assert channel.code == 302 channel.extract_cookies(cookies) - return channel.headers.getRawHeaders("Location")[0] + return get_location(channel) def initiate_sso_ui_auth( self, ui_auth_session_id: str, cookies: MutableMapping[str, str] diff --git a/tests/server.py b/tests/server.py index 103351b487..40cf5b12c3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,7 +16,17 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union +from typing import ( + AnyStr, + Callable, + Dict, + Iterable, + MutableMapping, + Optional, + Tuple, + Type, + Union, +) import attr from typing_extensions import Deque @@ -217,14 +227,12 @@ def make_request( path: Union[bytes, str], content: Union[bytes, str, JsonDict] = b"", access_token: Optional[str] = None, - request: Request = SynapseRequest, + request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index a1ba99ff14..d37736edf8 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -11,19 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor +from synapse.server import HomeServer from synapse.types import UserID +from synapse.util import Clock from tests import unittest class ProfileStoreTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") - def test_displayname(self): + def test_displayname(self) -> None: self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success( @@ -48,7 +51,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) - def test_avatar_url(self): + def test_avatar_url(self) -> None: self.get_success(self.store.create_profile(self.u_frank.localpart)) self.get_success( diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 2873e22ccf..fccab733c0 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -161,6 +161,54 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + def test__null_byte_in_display_name_properly_handled(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + + res = self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ) + # Check that we only got one result back + self.assertEqual(len(res), 1) + + # Check that alice's display name is "alice" + self.assertEqual(res[0]["display_name"], "alice") + + # Grab the event_id to use later + event_id = res[0]["event_id"] + + # Create a profile with the offending null byte in the display name + new_profile = {"displayname": "ali\u0000ce"} + + # Ensure that the change goes smoothly and does not fail due to the null byte + self.helper.change_membership( + room, + self.u_alice, + self.u_alice, + "join", + extra_data=new_profile, + tok=self.t_alice, + ) + + res2 = self.get_success( + self.store.db_pool.simple_select_list( + "room_memberships", + {"user_id": "@alice:test"}, + ["display_name", "event_id"], + ) + ) + # Check that we only have two results + self.assertEqual(len(res2), 2) + + # Filter out the previous event using the event_id we grabbed above + row = [row for row in res2 if row["event_id"] != event_id] + + # Check that alice's display name is now None + self.assertEqual(row[0]["display_name"], None) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py new file mode 100644 index 0000000000..ce782c7e1d --- /dev/null +++ b/tests/storage/test_stream.py @@ -0,0 +1,207 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.filtering import Filter +from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.types import JsonDict + +from tests.unittest import HomeserverTestCase + + +class PaginationTestCase(HomeserverTestCase): + """ + Test the pre-filtering done in the pagination code. + + This is similar to some of the tests in tests.rest.client.test_rooms but here + we ensure that the filtering done in the database is applied successfully. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["experimental_features"] = {"msc3440_enabled": True} + return config + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + self.helper.join( + room=self.room_id, user=self.second_user_id, tok=self.second_tok + ) + + self.third_user_id = self.register_user("third", "test") + self.third_tok = self.login("third", "test") + self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + + # An initial event with a relation from second user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 1"}, + tok=self.tok, + ) + self.event_id_1 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "π", + } + }, + tok=self.second_tok, + ) + + # Another event with a relation from third user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 2"}, + tok=self.tok, + ) + self.event_id_2 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": self.event_id_2, + } + }, + tok=self.third_tok, + ) + + # An event with no relations. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "No relations"}, + tok=self.tok, + ) + + def _filter_messages(self, filter: JsonDict) -> List[EventBase]: + """Make a request to /messages with a filter, returns the chunk of events.""" + + from_token = self.get_success( + self.hs.get_event_sources().get_current_token_for_pagination() + ) + + events, next_key = self.get_success( + self.hs.get_datastore().paginate_room_events( + room_id=self.room_id, + from_key=from_token.room_key, + to_key=None, + direction="b", + limit=10, + event_filter=Filter(self.hs, filter), + ) + ) + + return events + + def test_filter_relation_senders(self): + # Messages which second user reacted to. + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) + + # Messages which third user reacted to. + filter = {"io.element.relation_senders": [self.third_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_2) + + # Messages which either user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id, self.third_user_id] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_type(self): + # Messages which have annotations. + filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) + + # Messages which have references. + filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_2) + + # Messages which have either annotations or references. + filter = { + "io.element.relation_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c.event_id for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_senders_and_type(self): + # Messages which second user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id], + "io.element.relation_types": [RelationTypes.ANNOTATION], + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) + + def test_duplicate_relation(self): + """An event should only be returned once if there are multiple relations to it.""" + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "A", + } + }, + tok=self.second_tok, + ) + + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0].event_id, self.event_id_1) diff --git a/tests/test_federation.py b/tests/test_federation.py index 24fc77d7a7..3eef1c4c05 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -81,8 +81,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): origin, event, context, - state=None, - backfilled=False, ): return context diff --git a/tests/unittest.py b/tests/unittest.py index a9b60b7eeb..c9a08a3420 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,20 @@ import inspect import logging import secrets import time -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Any, + AnyStr, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from unittest.mock import Mock, patch from canonicaljson import json @@ -31,6 +44,7 @@ from twisted.python.threadpool import ThreadPool from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest from twisted.web.resource import Resource +from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership @@ -45,6 +59,7 @@ from synapse.logging.context import ( current_context, set_current_context, ) +from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock @@ -81,16 +96,13 @@ def around(target): return _around -T = TypeVar("T") - - class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the root logger's logging level while that test (case|method) runs.""" - def __init__(self, methodName, *args, **kwargs): - super().__init__(methodName, *args, **kwargs) + def __init__(self, methodName: str): + super().__init__(methodName) method = getattr(self, methodName) @@ -204,18 +216,18 @@ class HomeserverTestCase(TestCase): config dict. Attributes: - servlets (list[function]): List of servlet registration function. + servlets: List of servlet registration function. user_id (str): The user ID to assume if auth is hijacked. - hijack_auth (bool): Whether to hijack auth to return the user specified + hijack_auth: Whether to hijack auth to return the user specified in user_id. """ - servlets = [] - hijack_auth = True - needs_threadpool = False + hijack_auth: ClassVar[bool] = True + needs_threadpool: ClassVar[bool] = False + servlets: ClassVar[List[RegisterServletsFunc]] = [] - def __init__(self, methodName, *args, **kwargs): - super().__init__(methodName, *args, **kwargs) + def __init__(self, methodName: str): + super().__init__(methodName) # see if we have any additional config for this test method = getattr(self, methodName) @@ -287,9 +299,10 @@ class HomeserverTestCase(TestCase): None, ) - self.hs.get_auth().get_user_by_req = get_user_by_req - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token - self.hs.get_auth().get_access_token_from_request = Mock( + # Type ignore: mypy doesn't like us assigning to methods. + self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment] + self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] + self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment] return_value="1234" ) @@ -403,14 +416,12 @@ class HomeserverTestCase(TestCase): path: Union[bytes, str], content: Union[bytes, str, JsonDict] = b"", access_token: Optional[str] = None, - request: Type[T] = SynapseRequest, + request: Type[Request] = SynapseRequest, shorthand: bool = True, - federation_auth_origin: str = None, + federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -425,7 +436,7 @@ class HomeserverTestCase(TestCase): a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. - federation_auth_origin (bytes|None): if set to not-None, we will add a fake + federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. @@ -584,7 +595,7 @@ class HomeserverTestCase(TestCase): nonce_str += b"\x00notadmin" want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) - want_mac = want_mac.hexdigest() + want_mac_digest = want_mac.hexdigest() body = json.dumps( { @@ -593,7 +604,7 @@ class HomeserverTestCase(TestCase): "displayname": displayname, "password": password, "admin": admin, - "mac": want_mac, + "mac": want_mac_digest, "inhibit_login": True, } ) @@ -639,9 +650,7 @@ class HomeserverTestCase(TestCase): username, password, device_id=None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): """ Log in a user, and get an access token. Requires the Login API be |