summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_identity.py108
-rw-r--r--tests/handlers/test_profile.py11
-rw-r--r--tests/rest/client/test_identity.py173
-rw-r--r--tests/rest/client/test_retention.py319
-rw-r--r--tests/rest/client/test_room_access_rules.py576
-rw-r--r--tests/rest/client/third_party_rules.py79
-rw-r--r--tests/rest/client/v1/test_profile.py47
-rw-r--r--tests/rest/client/v1/test_rooms.py2
-rw-r--r--tests/rest/client/v2_alpha/test_password_policy.py181
-rw-r--r--tests/rest/client/v2_alpha/test_register.py92
-rw-r--r--tests/rulecheck/__init__.py14
-rw-r--r--tests/rulecheck/test_domainrulecheck.py342
-rw-r--r--tests/storage/test_profile.py13
13 files changed, 1935 insertions, 22 deletions
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644

index 0000000000..99ce94db52 --- /dev/null +++ b/tests/handlers/test_identity.py
@@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account + +from tests import unittest + + +class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.address = "test@test" + self.is_server_name = "testis" + self.rewritten_is_url = "int.testis" + + config = self.default_config() + config["trusted_third_party_id_servers"] = [ + self.is_server_name, + ] + config["rewrite_identity_server_urls"] = { + self.is_server_name: self.rewritten_is_url, + } + + mock_http_client = Mock(spec=[ + "post_urlencoded_get_json", + ]) + mock_http_client.post_urlencoded_get_json.return_value = defer.succeed({ + "address": self.address, + "medium": "email", + }) + + self.hs = self.setup_test_homeserver( + config=config, + simple_http_client=mock_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + + def test_rewritten_id_server(self): + """ + Tests that, when validating a 3PID association while rewriting the IS's server + name: + * the bind request is done against the rewritten hostname + * the original, non-rewritten, server name is stored in the database + """ + handler = self.hs.get_handlers().identity_handler + post_urlenc_get_json = self.hs.get_simple_http_client().post_urlencoded_get_json + store = self.hs.get_datastore() + + creds = { + "sid": "123", + "client_secret": "some_secret", + } + + # Make sure processing the mocked response goes through. + data = self.get_success(handler.bind_threepid( + { + "id_server": self.is_server_name, + "client_secret": creds["client_secret"], + "sid": creds["sid"], + }, + self.user_id, + )) + self.assertEqual(data.get("address"), self.address) + + # Check that the request was done against the rewritten server name. + post_urlenc_get_json.assert_called_once_with( + "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url, + { + 'sid': creds['sid'], + 'client_secret': creds["client_secret"], + 'mxid': self.user_id, + } + ) + + # Check that the original server name is saved in the database instead of the + # rewritten one. + id_servers = self.get_success(store.get_id_servers_user_bound( + self.user_id, "email", self.address + )) + self.assertEqual(id_servers, [self.is_server_name]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d60c124eec..45cbfeb9a4 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -67,13 +67,13 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) - self.handler = hs.get_profile_handler() @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield self.store.set_profile_displayname( + self.frank.localpart, "Frank", 1, + ) displayname = yield self.handler.get_displayname(self.frank) @@ -116,8 +116,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") - yield self.store.set_profile_displayname("caroline", "Caroline") + yield self.store.set_profile_displayname("caroline", "Caroline", 1) response = yield self.query_handlers["profile"]( {"user_id": "@caroline:test", "field": "displayname"} @@ -128,7 +127,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1, ) avatar_url = yield self.handler.get_avatar_url(self.frank) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 68949307d9..c9b9eff83e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@ import json +from mock import Mock + +from twisted.internet import defer + import synapse.rest.admin from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import account from tests import unittest -class IdentityTestCase(unittest.HomeserverTestCase): +class IdentityDisabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts fail when the HS's config disallows them.""" servlets = [ + account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, @@ -32,19 +39,110 @@ class IdentityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["trusted_third_party_id_servers"] = [ + "testis", + ] config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_disabled(self): + request, channel = self.make_request( + b"POST", "/createRoom", b"{}", access_token=self.tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + room_id = channel.json_body["room_id"] + + params = { + "id_server": "testis", + "medium": "email", + "address": "test@example.com", + } + request_data = json.dumps(params) + request_url = ( + "/rooms/%s/invite" % (room_id) + ).encode('ascii') + request, channel = self.make_request( + b"POST", request_url, request_data, access_token=self.tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + url = ("/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz") + request, channel = self.make_request("GET", url, access_token=self.tok) + self.render(request) + self.assertEqual(channel.result["code"], b"403", channel.result) + + def test_3pid_bulk_lookup_disabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [ + [ + "email", + "foo@bar.baz" + ], + [ + "email", + "john.doe@matrix.org" + ] + ] + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.result["code"], b"403", channel.result) - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") +class IdentityEnabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts succeed when the HS's config allows them.""" + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["enable_3pid_lookup"] = True + config["trusted_third_party_id_servers"] = [ + "testis", + ] + + mock_http_client = Mock(spec=[ + "get_json", + "post_json_get_json", + ]) + mock_http_client.get_json.return_value = defer.succeed((200, "{}")) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, + simple_http_client=mock_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_enabled(self): request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok + b"POST", "/createRoom", b"{}", access_token=self.tok, ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -58,7 +156,68 @@ class IdentityTestCase(unittest.HomeserverTestCase): request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok + b"POST", request_url, request_data, access_token=self.tok, ) self.render(request) - self.assertEquals(channel.result["code"], b"403", channel.result) + + get_json = self.hs.get_simple_http_client().get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + { + "address": "test@example.com", + "medium": "email", + }, + ) + + def test_3pid_lookup_enabled(self): + url = ("/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz") + request, channel = self.make_request("GET", url, access_token=self.tok) + self.render(request) + + get_json = self.hs.get_simple_http_client().get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + { + "address": "foo@bar.baz", + "medium": "email", + }, + ) + + def test_3pid_bulk_lookup_enabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [ + [ + "email", + "foo@bar.baz" + ], + [ + "email", + "john.doe@matrix.org" + ] + ] + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok, + ) + self.render(request) + + post_json = self.hs.get_simple_http_client().post_json_get_json + post_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/bulk_lookup", + { + "threepids": [ + [ + "email", + "foo@bar.baz" + ], + [ + "email", + "john.doe@matrix.org" + ] + ], + }, + ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644
index 0000000000..a040433994 --- /dev/null +++ b/tests/rest/client/test_retention.py
@@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 4, + }, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_hour_ms, + }, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": lifetime, + }, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success( + store.get_event( + resp.get("event_id") + ) + )) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success( + store.get_event( + valid_event_id + ) + )) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success(filter_events_for_client( + store, self.user_id, events + )) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 35, + }, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py new file mode 100644
index 0000000000..7e23add6b7 --- /dev/null +++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,576 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import random +import string + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.third_party_rules.access_rules import ( + ACCESS_RULE_DIRECT, + ACCESS_RULE_RESTRICTED, + ACCESS_RULE_UNRESTRICTED, + ACCESS_RULES_TYPE, +) + +from tests import unittest + + +class RoomAccessTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["third_party_event_rules"] = { + "module": "synapse.third_party_rules.access_rules.RoomAccessRules", + "config": { + "domains_forbidden_when_restricted": [ + "forbidden_domain" + ], + "id_server": "testis", + } + } + config["trusted_third_party_id_servers"] = [ + "testis", + ] + + def send_invite(destination, room_id, event_id, pdu): + return defer.succeed(pdu) + + def get_json(uri, args={}, headers=None): + address_domain = args["address"].split("@")[1] + return defer.succeed({"hs": address_domain}) + + def post_urlencoded_get_json(uri, args={}, headers=None): + token = ''.join(random.choice(string.ascii_letters) for _ in range(10)) + return defer.succeed({ + "token": token, + "public_keys": [ + { + "public_key": "serverpublickey", + "key_validity_url": "https://testis/pubkey/isvalid", + }, + { + "public_key": "phemeralpublickey", + "key_validity_url": "https://testis/pubkey/ephemeral/isvalid", + }, + ], + "display_name": "f...@b...", + }) + + mock_federation_client = Mock(spec=[ + "send_invite", + ]) + mock_federation_client.send_invite.side_effect = send_invite + + mock_http_client = Mock(spec=[ + "get_json", + "post_urlencoded_get_json" + ]) + # Mocking the response for /info on the IS API. + mock_http_client.get_json.side_effect = get_json + # Mocking the response for /store-invite on the IS API. + mock_http_client.post_urlencoded_get_json.side_effect = post_urlencoded_get_json + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + simple_http_client=mock_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.restricted_room = self.create_room() + self.unrestricted_room = self.create_room(rule=ACCESS_RULE_UNRESTRICTED) + self.direct_rooms = [ + self.create_room(direct=True), + self.create_room(direct=True), + self.create_room(direct=True), + ] + + self.invitee_id = self.register_user("invitee", "test") + self.invitee_tok = self.login("invitee", "test") + + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + ) + + def test_create_room_no_rule(self): + """Tests that creating a room with no rule will set the default value.""" + room_id = self.create_room() + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, ACCESS_RULE_RESTRICTED) + + def test_create_room_direct_no_rule(self): + """Tests that creating a direct room with no rule will set the default value.""" + room_id = self.create_room(direct=True) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, ACCESS_RULE_DIRECT) + + def test_create_room_valid_rule(self): + """Tests that creating a room with a valid rule will set the right value.""" + room_id = self.create_room(rule=ACCESS_RULE_UNRESTRICTED) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, ACCESS_RULE_UNRESTRICTED) + + def test_create_room_invalid_rule(self): + """Tests that creating a room with an invalid rule will set fail.""" + self.create_room(rule=ACCESS_RULE_DIRECT, expected_code=400) + + def test_create_room_direct_invalid_rule(self): + """Tests that creating a direct room with an invalid rule will fail. + """ + self.create_room(direct=True, rule=ACCESS_RULE_RESTRICTED, expected_code=400) + + def test_public_room(self): + """Tests that it's not possible to have a room with the public join rule and an + access rule that's not restricted. + """ + # Creating a room with the public_chat preset should succeed and set the access + # rule to restricted. + preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT) + self.assertEqual( + self.current_rule_in_room(preset_room_id), ACCESS_RULE_RESTRICTED, + ) + + # Creating a room with the public join rule in its initial state should succeed + # and set the access rule to restricted. + init_state_room_id = self.create_room(initial_state=[{ + "type": "m.room.join_rules", + "content": { + "join_rule": JoinRules.PUBLIC, + }, + }]) + self.assertEqual( + self.current_rule_in_room(init_state_room_id), ACCESS_RULE_RESTRICTED, + ) + + # Changing access rule to unrestricted should fail. + self.change_rule_in_room( + preset_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403, + ) + self.change_rule_in_room( + init_state_room_id, ACCESS_RULE_UNRESTRICTED, expected_code=403, + ) + + # Changing access rule to direct should fail. + self.change_rule_in_room( + preset_room_id, ACCESS_RULE_DIRECT, expected_code=403, + ) + self.change_rule_in_room( + init_state_room_id, ACCESS_RULE_DIRECT, expected_code=403, + ) + + # Changing join rule to public in an unrestricted room should fail. + self.change_join_rule_in_room( + self.unrestricted_room, JoinRules.PUBLIC, expected_code=403, + ) + # Changing join rule to public in an direct room should fail. + self.change_join_rule_in_room( + self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403, + ) + + # Creating a new room with the public_chat preset and an access rule that isn't + # restricted should fail. + self.create_room( + preset=RoomCreationPreset.PUBLIC_CHAT, rule=ACCESS_RULE_UNRESTRICTED, + expected_code=400, + ) + self.create_room( + preset=RoomCreationPreset.PUBLIC_CHAT, rule=ACCESS_RULE_DIRECT, + expected_code=400, + ) + + # Creating a room with the public join rule in its initial state and an access + # rule that isn't restricted should fail. + self.create_room( + initial_state=[{ + "type": "m.room.join_rules", + "content": { + "join_rule": JoinRules.PUBLIC, + }, + }], rule=ACCESS_RULE_UNRESTRICTED, expected_code=400, + ) + self.create_room( + initial_state=[{ + "type": "m.room.join_rules", + "content": { + "join_rule": JoinRules.PUBLIC, + }, + }], rule=ACCESS_RULE_DIRECT, expected_code=400, + ) + + def test_restricted(self): + """Tests that in restricted mode we're unable to invite users from blacklisted + servers but can invite other users. + """ + # We can't invite a user from a forbidden HS. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=403, + ) + + # We can invite a user which HS isn't forbidden. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:allowed_domain", + tok=self.tok, + expect_code=200, + ) + + # We can't send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.restricted_room, + expected_code=403, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.restricted_room, + expected_code=200, + ) + + def test_direct(self): + """Tests that, in direct mode, other users than the initial two can't be invited, + but the following scenario works: + * invited user joins the room + * invited user leaves the room + * room creator re-invites invited user + Also tests that a user from a HS that's in the list of forbidden domains (to use + in restricted mode) can be invited. + """ + not_invited_user = "@not_invited:forbidden_domain" + + # We can't invite a new user to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=403, + ) + + # The invited user can join the room. + self.helper.join( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can leave the room. + self.helper.leave( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can be re-invited to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + expect_code=200, + ) + + # If we're alone in the room and have always been the only member, we can invite + # someone. + self.helper.invite( + room=self.direct_rooms[1], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=200, + ) + + # We can't send a 3PID invite to a room that already has two members. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[0], + expected_code=403, + ) + + # We can't send a 3PID invite to a room that already has a pending invite. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[1], + expected_code=403, + ) + + # We can send a 3PID invite to a room in which we've always been the only member. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=200, + ) + + # We can send a 3PID invite to a room in which there's a 3PID invite. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=403, + ) + + def test_unrestricted(self): + """Tests that, in unrestricted mode, we can invite whoever we want, but we can + only change the power level of users that wouldn't be forbidden in restricted + mode. + """ + # We can invite + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:not_forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a power level event that doesn't redefine the default PL or set a + # non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={ + "users": { + self.user_id: 100, + "@test:not_forbidden_domain": 10, + }, + }, + tok=self.tok, + expect_code=200, + ) + + # We can't send a power level event that redefines the default PL and doesn't set + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={ + "users": { + self.user_id: 100, + "@test:not_forbidden_domain": 10, + }, + "users_default": 10, + }, + tok=self.tok, + expect_code=403, + ) + + # We can't send a power level event that doesn't redefines the default PL but sets + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={ + "users": { + self.user_id: 100, + "@test:forbidden_domain": 10, + }, + }, + tok=self.tok, + expect_code=403, + ) + + def test_change_rules(self): + """Tests that we can only change the current rule from restricted to + unrestricted. + """ + # We can change the rule from restricted to unrestricted. + self.change_rule_in_room( + room_id=self.restricted_room, + new_rule=ACCESS_RULE_UNRESTRICTED, + expected_code=200, + ) + + # We can't change the rule from restricted to direct. + self.change_rule_in_room( + room_id=self.restricted_room, + new_rule=ACCESS_RULE_DIRECT, + expected_code=403, + ) + + # We can't change the rule from unrestricted to restricted. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=ACCESS_RULE_RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from unrestricted to direct. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=ACCESS_RULE_DIRECT, + expected_code=403, + ) + + # We can't change the rule from direct to restricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=ACCESS_RULE_RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from direct to unrestricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=ACCESS_RULE_UNRESTRICTED, + expected_code=403, + ) + + def create_room( + self, direct=False, rule=None, preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + initial_state=None, expected_code=200, + ): + content = { + "is_direct": direct, + "preset": preset, + } + + if rule: + content["initial_state"] = [{ + "type": ACCESS_RULES_TYPE, + "state_key": "", + "content": { + "rule": rule, + } + }] + + if initial_state: + if "initial_state" not in content: + content["initial_state"] = [] + + content["initial_state"] += initial_state + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + json.dumps(content), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + if expected_code == 200: + return channel.json_body["room_id"] + + def current_rule_in_room(self, room_id): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["rule"] + + def change_rule_in_room(self, room_id, new_rule, expected_code=200): + data = { + "rule": new_rule, + } + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + json.dumps(data), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200): + data = { + "join_rule": new_join_rule, + } + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules), + json.dumps(data), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + def send_threepid_invite(self, address, room_id, expected_code=200): + params = { + "id_server": "testis", + "medium": "email", + "address": address, + } + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/invite" % room_id, + json.dumps(params), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py new file mode 100644
index 0000000000..7167fc56b6 --- /dev/null +++ b/tests/rest/client/third_party_rules.py
@@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class ThirdPartyRulesTestModule(object): + def __init__(self, config): + pass + + def check_event_allowed(self, event, context): + if event.type == "foo.bar.forbidden": + return False + else: + return True + + @staticmethod + def parse_config(config): + return config + + +class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["third_party_event_rules"] = { + "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule", + "config": {}, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def test_third_party_rules(self): + """Tests that a forbidden event is forbidden from being sent, but an allowed one + can be sent. + """ + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id, + {}, + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id, + {}, + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 72c7ed93cb..d932dd3c06 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py
@@ -289,3 +289,50 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): # if the user isn't already in the room), because we only want to # make sure the user isn't in the room. pass + + +class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_auth_for_profile_requests"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + def test_can_lookup_own_profile(self): + """Tests that a user can lookup their own profile without having to be in a room + if 'require_auth_for_profile_requests' is set to true in the server's config. + """ + request, channel = self.make_request( + "GET", "/profile/" + self.requester, access_token=self.requester_tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + request, channel = self.make_request( + "GET", + "/profile/" + self.requester + "/displayname", + access_token=self.requester_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + request, channel = self.make_request( + "GET", + "/profile/" + self.requester + "/avatar_url", + access_token=self.requester_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5f75ad7579..2d64b338be 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -920,7 +920,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() - config["restrict_public_rooms_to_local_users"] = True + config["allow_public_rooms_without_auth"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644
index 0000000000..17c22fe751 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + Therefore, each test in this test case that tests whether a password triggers the + right error code to be returned provides a password good enough to pass the previous + steps but not the one it's testing (nor any step that comes after). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request("GET", "/_matrix/client/r0/password_policy") + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body, { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, channel.result) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps({ + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + } + }) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b35b215446..b28de3663c 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,13 @@ import datetime import json import os +from mock import Mock +from six import ensure_binary + import pkg_resources +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -200,6 +205,53 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) +class RegisterHideProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/register" + + config = self.default_config() + config["enable_registration"] = True + config["show_users_in_user_directory"] = False + config["replicate_user_profiles_to"] = ["fakeserver"] + + mock_http_client = Mock(spec=[ + "get_json", + "post_json_get_json", + ]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, + simple_http_client=mock_http_client, + ) + + return self.hs + + def test_profile_hidden(self): + user_id = self.register_user("kermit", "monkey") + + post_json = self.hs.get_simple_http_client().post_json_get_json + + # We expect post_json_get_json to have been called twice: once with the original + # profile and once with the None profile resulting from the request to hide it + # from the user directory. + self.assertEqual(post_json.call_count, 2, post_json.call_args_list) + + # Get the args (and not kwargs) passed to post_json. + args = post_json.call_args[0] + # Make sure the last call was attempting to replicate profiles. + split_uri = args[0].split("/") + self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0]) + # Make sure the last profile update was overriding the user's profile to None. + self.assertEqual(args[1]["batch"][user_id], None, args[1]) + + class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ @@ -208,6 +260,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): login.register_servlets, sync.register_servlets, account_validity.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -323,6 +376,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "renew_at": 172800000, # Time in ms for 2 days "renew_by_email_enabled": True, "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", } # Email config. @@ -373,6 +428,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expected_html = self.hs.config.account_validity.account_renewed_html_content + self.assertEqual( + channel.result["body"], ensure_binary(expected_html), channel.result + ) + # Move 3 days forward. If the renewal failed, every authed request with # our access token should be denied from now, otherwise they should # succeed. @@ -381,6 +449,28 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + request, channel = self.make_request(b"GET", url) + self.render(request) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = None + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type = header[1] + self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = self.hs.config.account_validity.invalid_token_html_content + self.assertEqual( + channel.result["body"], ensure_binary(expected_html), channel.result + ) + def test_manual_email_send(self): self.email_attempts = [] @@ -415,7 +505,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): access_token=tok, ) self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(request.code, 200, channel.result) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py new file mode 100644
index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py new file mode 100644
index 0000000000..564fad0d77 --- /dev/null +++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json + +import synapse.rest.admin +from synapse.config._base import ConfigError +from synapse.rest.client.v1 import login, room +from synapse.rulecheck.domain_rule_checker import DomainRuleChecker + +from tests import unittest +from tests.server import make_request, render + + +class DomainRuleCheckerTestCase(unittest.TestCase): + def test_allowed(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + "domains_prevented_from_being_invited_to_published_rooms": ["target_two"] + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_one", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_two", "test:target_two", None, "room", False + ) + ) + + # User can invite internal user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test1:target_one", None, "room", False, True, + ) + ) + + # User can invite external user to a non-published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, False, + ) + ) + + def test_disallowed(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + "source_four": [], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_one", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_one", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_four", "test:target_one", None, "room", False + ) + ) + + # User cannot invite external user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, True, + ) + ) + + def test_default_allow(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_default_deny(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_config_parse(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + self.assertEquals(config, DomainRuleChecker.parse_config(config)) + + def test_config_parse_failure(self): + config = { + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + } + } + self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config) + + +class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + hijack_auth = False + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["trusted_third_party_id_servers"] = [ + "localhost", + ] + + config["spam_checker"] = { + "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker", + "config": { + "default": True, + "domain_mapping": {}, + "can_only_join_rooms_with_invite": True, + "can_only_create_one_to_one_rooms": True, + "can_only_invite_during_room_creation": True, + "can_invite_by_third_party_id": False, + }, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.admin_user_id = self.register_user("admin_user", "pass", admin=True) + self.admin_access_token = self.login("admin_user", "pass") + + self.normal_user_id = self.register_user("normal_user", "pass", admin=False) + self.normal_access_token = self.login("normal_user", "pass") + + self.other_user_id = self.register_user("other_user", "pass", admin=False) + + def test_admin_can_create_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + def test_normal_user_cannot_create_empty_room(self): + channel = self._create_room(self.normal_access_token) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_cannot_create_room_with_multiple_invites(self): + channel = self._create_room( + self.normal_access_token, + content={"invite": [self.other_user_id, self.admin_user_id]}, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly counts both normal and third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [self.other_user_id], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly rejects third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_can_room_with_single_invites(self): + channel = self._create_room( + self.normal_access_token, content={"invite": [self.other_user_id]} + ) + assert channel.result["code"] == b"200", channel.result + + def test_cannot_join_public_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403 + ) + + def test_can_join_invited_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + def test_cannot_invite(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + def test_cannot_3pid_invite(self): + """Test that unbound 3pid invites get rejected. + """ + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, + tok=self.normal_access_token, + expect_code=200, + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + request, channel = self.make_request( + "POST", + "rooms/%s/invite" % (room_id), + { + "address": "foo@bar.com", + "medium": "email", + "id_server": "localhost" + }, + access_token=self.normal_access_token, + ) + self.render(request) + self.assertEqual(channel.code, 403, channel.result["body"]) + + def _create_room(self, token, content={}): + path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,) + + request, channel = make_request( + self.hs.get_reactor(), + "POST", + path, + content=json.dumps(content).encode("utf8"), + ) + render(request, self.resource, self.hs.get_reactor()) + + return channel diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 45824bd3b2..c125a0d797 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -34,20 +34,19 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) - - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield self.store.set_profile_displayname( + self.u_frank.localpart, "Frank", 1, + ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + (yield self.store.get_profile_displayname(self.u_frank.localpart)) ) @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.localpart, "http://my.site/here", 1, ) self.assertEquals(