diff options
Diffstat (limited to 'tests/rest/client')
-rw-r--r-- | tests/rest/client/test_ephemeral_message.py | 101 | ||||
-rw-r--r-- | tests/rest/client/test_power_levels.py | 205 | ||||
-rw-r--r-- | tests/rest/client/test_retention.py | 293 | ||||
-rw-r--r-- | tests/rest/client/test_transactions.py | 16 | ||||
-rw-r--r-- | tests/rest/client/v1/test_events.py | 37 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 425 | ||||
-rw-r--r-- | tests/rest/client/v1/test_presence.py | 3 | ||||
-rw-r--r-- | tests/rest/client/v1/test_profile.py | 12 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 996 | ||||
-rw-r--r-- | tests/rest/client/v1/test_typing.py | 14 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 153 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_account.py | 386 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_auth.py | 277 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_filter.py | 2 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_password_policy.py | 179 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_register.py | 180 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_sync.py | 183 |
17 files changed, 3315 insertions, 147 deletions
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py new file mode 100644 index 0000000000..913ea3c98e --- /dev/null +++ b/tests/rest/client/test_power_levels.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync + +from tests.unittest import HomeserverTestCase + + +class PowerLevelsTestCase(HomeserverTestCase): + """Tests that power levels are enforced in various situations""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + # register a room admin, moderator and regular user + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + self.mod_user_id = self.register_user("mod", "pass") + self.mod_access_token = self.login("mod", "pass") + self.user_user_id = self.register_user("user", "pass") + self.user_access_token = self.login("user", "pass") + + # Create a room + self.room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # Invite the other users + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.mod_user_id, + ) + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.user_user_id, + ) + + # Make the other users join the room + self.helper.join( + room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token + ) + self.helper.join( + room=self.room_id, user=self.user_user_id, tok=self.user_access_token + ) + + # Mod the mod + room_power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.admin_access_token, + ) + + # Update existing power levels with mod at PL50 + room_power_levels["users"].update({self.mod_user_id: 50}) + + self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + ) + + def test_non_admins_cannot_enable_room_encryption(self): + # have the mod try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_send_server_acl(self): + # have the mod try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the mod try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_admins_can_enable_room_encryption(self): + # have the admin try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_send_server_acl(self): + # have the admin try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the admin try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..95475bb651 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 4}, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_hour_ms}, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": lifetime}, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success(store.get_event(resp.get("event_id")))) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success(store.get_event(valid_event_id))) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success( + filter_events_for_client(storage, self.user_id, events) + ) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Get the create event to, later, check that we can still access it. + message_handler = self.hs.get_message_handler() + create_event = self.get_success( + message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + ) + + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + # Check that we can still access state events that were sent before the event that + # has been purged. + self.get_event(room_id, create_event.event_id) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 35}, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a3d7e3c046..171632e195 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,7 +2,7 @@ from mock import Mock, call from twisted.internet import defer, reactor -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock @@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test(): with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) - self.assertIs(LoggingContext.current_context(), c1) + self.assertIs(current_context(), c1) self.assertEqual(res, "yay") # run the test twice in parallel d = defer.gatherResults([test(), test()]) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) yield d - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks def test_does_not_cache_exceptions(self): @@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_does_not_cache_failures(self): @@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self): diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index f340b7e851..f75520877f 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room @@ -40,17 +40,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(config=config) hs.get_handlers().federation_handler = Mock() return hs - def prepare(self, hs, reactor, clock): + def prepare(self, reactor, clock, hs): # register an account self.user_id = self.register_user("sid1", "pass") @@ -134,3 +130,30 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # someone else set topic, expect 6 (join,send,topic,join,send,topic) pass + + +class GetEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, hs, reactor, clock): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_event_via_events(self): + resp = self.helper.send(self.room_id, tok=self.token) + event_id = resp["event_id"] + + request, channel = self.make_request( + "GET", "/events/" + event_id, access_token=self.token, + ) + self.render(request) + self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eae5411325..9033f09fd2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,7 +1,13 @@ import json +import time +import urllib.parse + +from mock import Mock + +import jwt import synapse.rest.admin -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha.account import WhoamiRestServlet @@ -17,12 +23,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, + logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -31,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.config.rc_login_address.burst_count = 5 - self.hs.config.rc_login_address.per_second = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -73,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.config.rc_login_account.burst_count = 5 - self.hs.config.rc_login_account.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -112,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.config.rc_login_failed_attempts.burst_count = 5 - self.hs.config.rc_login_failed_attempts.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -252,3 +288,370 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(channel.code, 200, channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + request, channel = self.make_request( + b"POST", "/logout/all", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +class CASTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["cas_config"] = { + "enabled": True, + "server_url": "https://fake.test", + "service_url": "https://matrix.goodserver.com:8448", + } + + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return ( + """ + <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'> + <cas:authenticationSuccess> + <cas:user>%s</cas:user> + <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket> + <cas:proxies> + <cas:proxy>https://proxy2/pgtUrl</cas:proxy> + <cas:proxy>https://proxy1/pgtUrl</cas:proxy> + </cas:proxies> + </cas:authenticationSuccess> + </cas:serviceResponse> + """ + % cas_user_id + ) + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, proxied_http_client=mocked_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Test that the response is HTML. + self.assertEqual(channel.code, 200) + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url + """ + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account(self.user_id, False) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) + + +class JWTTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + jwt_secret = "secret" + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_secret + self.hs.config.jwt_algorithm = "HS256" + return self.hs + + def jwt_encode(self, token, secret=jwt_secret): + return jwt.encode(token, secret, "HS256").decode("ascii") + + def jwt_login(self, *args): + params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + return channel + + def test_login_jwt_valid_registered(self): + self.register_user("kermit", "monkey") + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_valid_unregistered(self): + channel = self.jwt_login({"sub": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, "notsecret") + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + def test_login_jwt_expired(self): + channel = self.jwt_login({"sub": "frog", "exp": 864000}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "JWT expired") + + def test_login_jwt_not_before(self): + now = int(time.time()) + channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + def test_login_no_sub(self): + channel = self.jwt_login({"username": "root"}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + def test_login_no_token(self): + params = json.dumps({"type": "m.login.jwt"}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") + + +# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use +# RSS256, with a public key configured in synapse as "jwt_secret", and tokens +# signed by the private key. +class JWTPubKeyTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + ] + + # This key's pubkey is used as the jwt_secret setting of synapse. Valid + # tokens are signed by this and validated using the pubkey. It is generated + # with `openssl genrsa 512` (not a secure way to generate real keys, but + # good enough for tests!) + jwt_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", + "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", + "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", + "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", + "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", + "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", + "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", + "-----END RSA PRIVATE KEY-----", + ] + ) + + # Generated with `openssl rsa -in foo.key -pubout`, with the the above + # private key placed in foo.key (jwt_privatekey). + jwt_pubkey = "\n".join( + [ + "-----BEGIN PUBLIC KEY-----", + "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", + "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", + "-----END PUBLIC KEY-----", + ] + ) + + # This key is used to sign tokens that shouldn't be accepted by synapse. + # Generated just like jwt_privatekey. + bad_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", + "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", + "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", + "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", + "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", + "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", + "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", + "-----END RSA PRIVATE KEY-----", + ] + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_pubkey + self.hs.config.jwt_algorithm = "RS256" + return self.hs + + def jwt_encode(self, token, secret=jwt_privatekey): + return jwt.encode(token, secret, "RS256").decode("ascii") + + def jwt_login(self, *args): + params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + return channel + + def test_login_jwt_valid(self): + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b68707..0fdff79aa7 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b3772..8df58b4a63 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req @@ -229,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -301,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fe741637f5..4886bbb401 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,14 +20,18 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import Membership -from synapse.rest.client.v1 import login, profile, room +from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.handlers.pagination import PurgeStatus +from synapse.rest.client.v1 import directory, login, profile, room +from synapse.rest.client.v2_alpha import account +from synapse.types import JsonDict, RoomAlias +from synapse.util.stringutils import random_string from tests import unittest @@ -40,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) - self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) @@ -484,6 +485,15 @@ class RoomsCreateTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code) + def test_post_room_invitees_invalid_mxid(self): + # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 + # Note the trailing space in the MXID here! + request, channel = self.make_request( + "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' + ) + self.render(request) + self.assertEquals(400, channel.code) + class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ @@ -802,6 +812,78 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ @@ -998,3 +1080,899 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_context_filter_labels(self): + """Test that we can filter by a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = channel.json_body["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 2, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 4, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 1, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. + + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. + + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + + +class RoomAliasListTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + def test_no_aliases(self): + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(res["aliases"], []) + + def test_not_in_room(self): + self.register_user("user", "test") + user_tok = self.login("user", "test") + res = self._get_aliases(user_tok, expected_code=403) + self.assertEqual(res["errcode"], "M_FORBIDDEN") + + def test_admin_user(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.register_user("user", "test", admin=True) + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def test_with_aliases(self): + alias1 = self._random_alias() + alias2 = self._random_alias() + + self._set_alias_via_directory(alias1) + self._set_alias_via_directory(alias2) + + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(set(res["aliases"]), {alias1, alias2}) + + def test_peekable_room(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "world_readable"}, + tok=self.room_owner_tok, + ) + + self.register_user("user", "test") + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" + % (self.room_id,), + access_token=access_token, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + if expected_code == 200: + self.assertIsInstance(res["aliases"], list) + return res + + def _random_alias(self) -> str: + return RoomAlias(random_string(5), self.hs.hostname).to_string() + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30fb77bac8..18260bb90e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() def get_user_by_access_token(token=None, allow_guest=False): @@ -109,7 +103,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index cdded88b7f..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,9 +18,12 @@ import json import time +from typing import Any, Dict, Optional import attr +from twisted.web.resource import Resource + from synapse.api.constants import Membership from tests.server import make_request, render @@ -33,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -106,13 +112,22 @@ class RestHelper(object): self.auth_user_id = temp_id def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) if body is None: body = "body_text_here" - path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, "m.room.message", content, txn_id, tok, expect_code + ) + + def send_event( + self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) if tok: path = path + "?access_token=%s" % tok @@ -128,7 +143,34 @@ class RestHelper(object): return channel.json_body - def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""): + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( room_id, event_type, @@ -137,9 +179,13 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8") - ) + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + request, channel = make_request(self.hs.get_reactor(), method, path, content) + render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( @@ -148,3 +194,94 @@ class RestHelper(object): ) return channel.json_body + + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + request, channel = make_request( + self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok + ) + request.requestHeaders.addRawHeader( + b"Content-Length", str(image_length).encode("UTF-8") + ) + request.render(resource) + self.hs.get_reactor().pump([100]) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 920de41de4..3ab611f618 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -23,8 +23,9 @@ from email.parser import Parser import pkg_resources import synapse.rest.admin -from synapse.api.constants import LoginType -from synapse.rest.client.v1 import login +from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest @@ -45,7 +46,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): self.email_attempts.append(msg) return @@ -178,6 +179,22 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. + """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + def _request_token(self, email, client_secret): request, channel = self.make_request( "POST", @@ -244,16 +261,72 @@ class DeactivateTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, account.register_servlets, + room.register_servlets, ] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs + self.hs = self.setup_test_homeserver() + return self.hs def test_deactivate_account(self): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") + self.deactivate(user_id, tok) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + request, channel = self.make_request("GET", "account/whoami") + self.render(request) + self.assertEqual(request.code, 401) + + @unittest.INFO + def test_pending_invites(self): + """Tests that deactivating a user rejects every pending invite for them.""" + store = self.hs.get_datastore() + + inviter_id = self.register_user("inviter", "test") + inviter_tok = self.login("inviter", "test") + + invitee_id = self.register_user("invitee", "test") + invitee_tok = self.login("invitee", "test") + + # Make @inviter:test invite @invitee:test in a new room. + room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) + self.helper.invite( + room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok + ) + + # Make sure the invite is here. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 1, pending_invites) + self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) + + # Deactivate @invitee:test. + self.deactivate(invitee_id, invitee_tok) + + # Check that the invite isn't there anymore. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 0, pending_invites) + + # Check that the membership of @invitee:test in the room is now "leave". + memberships = self.get_success( + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) + ) + self.assertEqual(len(memberships), 1, memberships) + self.assertEqual(memberships[0].room_id, room_id, memberships) + + def deactivate(self, user_id, tok): request_data = json.dumps( { "auth": { @@ -270,12 +343,303 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(request.code, 200) - store = self.hs.get_datastore() - # Check that the user has been marked as deactivated. - self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) - # Check that this access token has been invalidated. - request, channel = self.make_request("GET", "account/whoami") self.render(request) - self.assertEqual(request.code, 401) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed + """ + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed + """ + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b9ef46e8fb..293ccfba2b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,22 +12,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import List, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType -from synapse.rest.client.v2_alpha import auth, register +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http.site import SynapseRequest +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.types import JsonDict from tests import unittest +from tests.server import FakeChannel + + +class DummyRecaptchaChecker(UserInteractiveAuthChecker): + def __init__(self, hs): + super().__init__(hs) + self.recaptcha_attempts = [] + + def check_auth(self, authdict, clientip): + self.recaptcha_attempts.append((authdict, clientip)) + return succeed(True) + + +class DummyPasswordChecker(UserInteractiveAuthChecker): + def check_auth(self, authdict, clientip): + return succeed(authdict["identifier"]["user"]) class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False @@ -44,28 +63,55 @@ class FallbackAuthTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - self.recaptcha_attempts = [] + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" + request, channel = self.make_request( + "POST", "register", body + ) # type: SynapseRequest, FakeChannel + self.render(request) - def _recaptcha(authdict, clientip): - self.recaptcha_attempts.append((authdict, clientip)) - return succeed(True) + self.assertEqual(request.code, expected_response) + return channel - auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha + def recaptcha( + self, session: str, expected_post_response: int, post_session: str = None + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session - @unittest.INFO - def test_fallback_captcha(self): + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) # type: SynapseRequest, FakeChannel + self.render(request) + self.assertEqual(request.code, 200) request, channel = self.make_request( "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, + "auth/m.login.recaptcha/fallback/web?session=" + + post_session + + "&g-recaptcha-response=a", ) self.render(request) + self.assertEqual(request.code, expected_post_response) + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + @unittest.INFO + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec - self.assertEqual(request.code, 401) + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + # Grab the session session = channel.json_body["session"] # Assert our configured public key is being given @@ -73,39 +119,198 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" ) - request, channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step, we can then send a + # request to the register API with the session in the authdict. + channel = self.register(200, {"auth": {"session": session}}) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + # Make the initial request to register. (Later on a different password + # will be used.) + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) + + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.user_tok = self.login("test", self.user_pass) + + def get_device_ids(self) -> List[str]: + # Get the list of devices so one can be deleted. + request, channel = self.make_request( + "GET", "devices", access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel self.render(request) + + # Get the ID of the device. self.assertEqual(request.code, 200) + return [d["device_id"] for d in channel.json_body["devices"]] + def delete_device( + self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + ) -> FakeChannel: + """Delete an individual device.""" request, channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + session - + "&g-recaptcha-response=a", - ) + "DELETE", "devices/" + device, body, access_token=self.user_tok + ) # type: SynapseRequest, FakeChannel self.render(request) - self.assertEqual(request.code, 200) - # The recaptcha handler is called with the response given - self.assertEqual(len(self.recaptcha_attempts), 1) - self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a") + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) - # also complete the dummy auth + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} - ) + "POST", "delete_devices", body, access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel self.render(request) - # Now we should have fufilled a complete auth flow, including - # the recaptcha fallback step, we can then send a - # request to the register API with the session in the authdict. - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session}} + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + device_id = self.get_device_ids()[0] + + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, ) - self.render(request) - self.assertEqual(channel.code, 200) - # We're given a registered user. - self.assertEqual(channel.json_body["user_id"], "@user:test") + def test_can_change_body(self): + """ + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. + self.delete_devices( + 200, + { + "devices": [device_ids[1]], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(device_ids[0], 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_device( + device_ids[1], + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index f42a8efbf4..e0e9e94fbf 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.render(request) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.result["code"], b"404") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644 index 0000000000..c57072f50c --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index ab4d7d70d0..7deaf5b24a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -25,28 +25,26 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest +from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [register.register_servlets] - - def make_homeserver(self, reactor, clock): - - self.url = b"/_matrix/client/r0/register" + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] + url = b"/_matrix/client/r0/register" - self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False - self.hs.config.allow_guest_access = True - - return self.hs + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config def test_POST_appservice_registration_valid(self): user_id = "@as_user_kermit:test" @@ -149,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") @@ -171,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): params = { "username": "kermit" + str(i), @@ -199,6 +193,115 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + def test_advertised_flows(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, + } + ) + def test_advertised_flows_captcha_and_terms_and_3pids(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + self.assertCountEqual( + [ + ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], + ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], + ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], + [ + "m.login.recaptcha", + "m.login.terms", + "m.login.msisdn", + "m.login.email.identity", + ], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_advertised_flows_no_msisdn_email_required(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [["m.login.email.identity"]], (f["stages"] for f in flows) + ) + + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. + """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + class AccountValidityTestCase(unittest.HomeserverTestCase): @@ -207,6 +310,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, + logout.register_servlets, account_validity.register_servlets, ] @@ -299,6 +403,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + request, channel = self.make_request(b"POST", "/logout", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -330,9 +467,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(*args, **kwargs): + async def sendmail(*args, **kwargs): self.email_attempts.append((args, kwargs)) - return config["email"] = { "enable_notifs": True, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 71895094bd..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from mock import Mock +import json import synapse.rest.admin +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -26,14 +27,12 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" - servlets = [sync.register_servlets] - - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() - ) - return hs + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def test_sync_argless(self): request, channel = self.make_request("GET", "/sync") @@ -41,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - [ - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - ] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) def test_sync_presence_disabled(self): @@ -64,11 +61,149 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - ["next_batch", "rooms", "account_data", "to_device", "device_lists"] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) + ) + + +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + """Test that we can filter by a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, ) + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + request, channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + class SyncTypingTests(unittest.HomeserverTestCase): |