summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py29
-rw-r--r--tests/api/test_ratelimiting.py4
-rw-r--r--tests/appservice/test_appservice.py1
-rw-r--r--tests/handlers/test_appservice.py20
-rw-r--r--tests/handlers/test_device.py2
-rw-r--r--tests/handlers/test_identity.py116
-rw-r--r--tests/handlers/test_message.py2
-rw-r--r--tests/handlers/test_oidc.py24
-rw-r--r--tests/handlers/test_profile.py10
-rw-r--r--tests/handlers/test_register.py108
-rw-r--r--tests/handlers/test_stats.py8
-rw-r--r--tests/handlers/test_user_directory.py135
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py2
-rw-r--r--tests/logging/__init__.py34
-rw-r--r--tests/logging/test_remote_handler.py169
-rw-r--r--tests/logging/test_structured.py214
-rw-r--r--tests/logging/test_terse_json.py253
-rw-r--r--tests/push/test_email.py31
-rw-r--r--tests/push/test_http.py22
-rw-r--r--tests/replication/_base.py13
-rw-r--r--tests/replication/tcp/streams/test_events.py2
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/rest/admin/test_device.py17
-rw-r--r--tests/rest/admin/test_event_reports.py196
-rw-r--r--tests/rest/admin/test_media.py568
-rw-r--r--tests/rest/admin/test_user.py422
-rw-r--r--tests/rest/client/test_identity.py145
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/rest/client/test_room_access_rules.py1083
-rw-r--r--tests/rest/client/v1/test_login.py16
-rw-r--r--tests/rest/client/v2_alpha/test_register.py259
-rw-r--r--tests/rulecheck/__init__.py14
-rw-r--r--tests/rulecheck/test_domainrulecheck.py334
-rw-r--r--tests/server.py8
-rw-r--r--tests/storage/test_cleanup_extrems.py6
-rw-r--r--tests/storage/test_client_ips.py2
-rw-r--r--tests/storage/test_event_metrics.py4
-rw-r--r--tests/storage/test_main.py2
-rw-r--r--tests/storage/test_profile.py4
-rw-r--r--tests/storage/test_registration.py10
-rw-r--r--tests/storage/test_roommember.py4
-rw-r--r--tests/test_federation.py4
-rw-r--r--tests/test_types.py22
-rw-r--r--tests/test_utils/event_injection.py2
-rw-r--r--tests/unittest.py4
-rw-r--r--tests/utils.py2
46 files changed, 3793 insertions, 538 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index cb6f29d670..0fd55f428a 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -29,6 +29,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest @@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} + user_info = TokenLookupResult( + user_id=self.test_user, token_id=5, device_id="device" + ) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): - user_info = {"name": self.test_user, "token_id": "ditto"} + user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( return_value=defer.succeed( - {"name": "@baldrick:matrix.org", "device_id": "device"} + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) ) @@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(macaroon.serialize()) ) - user = user_info["user"] - self.assertEqual(UserID.from_string(user_id), user) + self.assertEqual(user_id, user_info.user_id) # TODO: device_id should come from the macaroon, but currently comes # from the db. - self.assertEqual(user_info["device_id"], "device") + self.assertEqual(user_info.device_id, "device") @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): @@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(serialized) ) - user = user_info["user"] - is_guest = user_info["is_guest"] - self.assertEqual(UserID.from_string(user_id), user) - self.assertTrue(is_guest) + self.assertEqual(user_id, user_info.user_id) + self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks @@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase): if token != tok: return defer.succeed(None) return defer.succeed( - { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + ) ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, + None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, + None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( id="unique_identifier", + sender="@as:test", url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ee4f3da31c..53763cd0f9 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -42,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) - @defer.inlineCallbacks def test_notify_interested_services(self): interested_service = self._mkservice(is_interested=True) services = [ @@ -62,14 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) - @defer.inlineCallbacks def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -83,12 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -102,9 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 4512c51311..875aaec2c6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py
@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # make sure that our device ID has changed user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) - self.assertEqual(user_info["device_id"], retrieved_device_id) + self.assertEqual(user_info.device_id, retrieved_device_id) # make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py new file mode 100644
index 0000000000..b7d340bcb8 --- /dev/null +++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account + +from tests import unittest + + +class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.address = "test@test" + self.is_server_name = "testis" + self.is_server_url = "https://testis" + self.rewritten_is_url = "https://int.testis" + + config = self.default_config() + config["trusted_third_party_id_servers"] = [self.is_server_name] + config["rewrite_identity_server_urls"] = { + self.is_server_url: self.rewritten_is_url + } + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.get_json.side_effect = defer.succeed({}) + mock_http_client.post_json_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_blacklisting_http_client.get_json.side_effect = defer.succeed({}) + mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_identity_handler().blacklisting_http_client = ( + mock_blacklisting_http_client + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + + def test_rewritten_id_server(self): + """ + Tests that, when validating a 3PID association while rewriting the IS's server + name: + * the bind request is done against the rewritten hostname + * the original, non-rewritten, server name is stored in the database + """ + handler = self.hs.get_identity_handler() + post_json_get_json = handler.blacklisting_http_client.post_json_get_json + store = self.hs.get_datastore() + + creds = {"sid": "123", "client_secret": "some_secret"} + + # Make sure processing the mocked response goes through. + data = self.get_success( + handler.bind_threepid( + client_secret=creds["client_secret"], + sid=creds["sid"], + mxid=self.user_id, + id_server=self.is_server_name, + use_v2=False, + ) + ) + self.assertEqual(data.get("address"), self.address) + + # Check that the request was done against the rewritten server name. + post_json_get_json.assert_called_once_with( + "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,), + { + "sid": creds["sid"], + "client_secret": creds["client_secret"], + "mxid": self.user_id, + }, + headers={}, + ) + + # Check that the original server name is saved in the database instead of the + # rewritten one. + id_servers = self.get_success( + store.get_id_servers_user_bound(self.user_id, "email", self.address) + ) + self.assertEqual(id_servers, [self.is_server_name]) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9f6f21a6e2..2e0fea04af 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py
@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.info = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token,) ) - self.token_id = self.info["token_id"] + self.token_id = self.info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b6f436c016..0d51705849 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -394,7 +394,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) code = "code" @@ -414,9 +421,8 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent self.get_success(self.handler.handle_oidc_callback(request)) @@ -621,7 +627,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) state = "state" @@ -637,9 +650,8 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [b"code"] request.args[b"state"] = [state.encode("utf-8")] - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [b"Browser"] request.getClientIP.return_value = "10.0.0.1" + request.get_user_agent.return_value = "Browser" self.get_success(self.handler.handle_oidc_callback(request)) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..1999bcbb38 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -65,7 +65,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): yield defer.ensureDeferred( - self.store.set_profile_displayname(self.frank.localpart, "Frank") + self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) ) displayname = yield defer.ensureDeferred( @@ -113,7 +113,7 @@ class ProfileTestCase(unittest.TestCase): # Setting displayname for the first time is allowed yield defer.ensureDeferred( - self.store.set_profile_displayname(self.frank.localpart, "Frank") + self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) ) self.assertEquals( @@ -166,7 +166,7 @@ class ProfileTestCase(unittest.TestCase): def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) yield defer.ensureDeferred( - self.store.set_profile_displayname("caroline", "Caroline") + self.store.set_profile_displayname("caroline", "Caroline", 1) ) response = yield defer.ensureDeferred( @@ -181,7 +181,7 @@ class ProfileTestCase(unittest.TestCase): def test_get_my_avatar(self): yield defer.ensureDeferred( self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1 ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -232,7 +232,7 @@ class ProfileTestCase(unittest.TestCase): # Setting displayname for the first time is allowed yield defer.ensureDeferred( self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1 ) ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..252d3e7996 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -18,9 +18,15 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha.register import ( + _map_email_to_displayname, + register_servlets, +) from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester +from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -31,6 +37,10 @@ from .. import unittest class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ + servlets = [ + register_servlets, + ] + def make_homeserver(self, reactor, clock): hs_config = self.default_config() @@ -517,6 +527,104 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(requester.shadow_banned) + def test_email_to_displayname_mapping(self): + """Test that custom emails are mapped to new user displaynames correctly""" + self._check_mapping( + "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]" + ) + + self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]") + + self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]") + + # Multibyte unicode characters + self._check_mapping( + "j\u030a\u0065an-poppy.seed@example.com", + "J\u030a\u0065an-Poppy Seed [Example]", + ) + + def _check_mapping(self, i, expected): + result = _map_email_to_displayname(i) + self.assertEqual(result, expected) + + @override_config( + { + "bind_new_user_emails_to_sydent": "https://is.example.com", + "registrations_require_3pid": ["email"], + "account_threepid_delegates": {}, + "email": { + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + }, + "public_baseurl": "http://localhost", + } + ) + def test_user_email_bound_via_sydent_internal_api(self): + """Tests that emails are bound after registration if this option is set""" + # Register user with an email address + email = "alice@example.com" + + # Mock Synapse's threepid validator + get_threepid_validation_session = Mock( + return_value=make_awaitable( + {"medium": "email", "address": email, "validated_at": 0} + ) + ) + self.store.get_threepid_validation_session = get_threepid_validation_session + delete_threepid_session = Mock(return_value=make_awaitable(None)) + self.store.delete_threepid_session = delete_threepid_session + + # Mock Synapse's http json post method to check for the internal bind call + post_json_get_json = Mock(return_value=make_awaitable(None)) + self.hs.get_simple_http_client().post_json_get_json = post_json_get_json + + # Retrieve a UIA session ID + channel = self.uia_register( + 401, {"username": "alice", "password": "nobodywillguessthis"} + ) + session_id = channel.json_body["session"] + + # Register our email address using the fake validation session above + channel = self.uia_register( + 200, + { + "username": "alice", + "password": "nobodywillguessthis", + "auth": { + "session": session_id, + "type": "m.login.email.identity", + "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"}, + }, + }, + ) + self.assertEqual(channel.json_body["user_id"], "@alice:test") + + # Check that a bind attempt was made to our fake identity server + post_json_get_json.assert_called_with( + "https://is.example.com/_matrix/identity/internal/bind", + {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"}, + ) + + # Check that we stored a mapping of this bind + bound_threepids = self.get_success( + self.store.user_get_bound_threepids("@alice:test") + ) + self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}]) + + def uia_register(self, expected_response: int, body: dict) -> FakeChannel: + """Make a register request.""" + request, channel = self.make_request( + "POST", "register", body + ) # type: SynapseRequest, FakeChannel + self.render(request) + + self.assertEqual(request.code, expected_response) + return channel + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 312c0a0d41..0229f58315 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest # The expected number of state events in a fresh public room. EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5 + # The expected number of state events in a fresh private room. -EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 +# +# Note: we increase this by 2 on the dinsic branch as we send +# a "im.vector.room.access_rules" state event into new private rooms, +# and an encryption state event as all private rooms are encrypted +# by default +EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7 class StatsRoomTests(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 87be94111f..48f750d357 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client.v2_alpha import account, account_validity, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest @@ -549,3 +549,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) + + +class UserInfoTestCase(unittest.FederatingHomeserverTestCase): + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + account.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # Set accounts to expire after a week + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + return config + + def prepare(self, reactor, clock, hs): + super(UserInfoTestCase, self).prepare(reactor, clock, hs) + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def test_user_info(self): + """Test /users/info for local users from the Client-Server API""" + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request info about each user from user_three + request, channel = self.make_request( + "POST", + path="/_matrix/client/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + access_token=user_three_token, + shorthand=False, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def test_user_info_federation(self): + """Test that /users/info can be called from the Federation API, and + and that we can query remote users from the Client-Server API + """ + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request information about our local users from the perspective of a remote server + request, channel = self.make_request( + "POST", + path="/_matrix/federation/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def setup_test_users(self): + """Create an admin user and three test users, each with a different state""" + + # Create an admin user to expire other users with + self.register_user("admin", "adminpassword", admin=True) + admin_token = self.login("admin", "adminpassword") + + # Create three users + user_one = self.register_user("alice", "pass") + user_one_token = self.login("alice", "pass") + user_two = self.register_user("bob", "pass") + user_three = self.register_user("carl", "pass") + user_three_token = self.login("carl", "pass") + + # Deactivate user_one + self.deactivate(user_one, user_one_token) + + # Expire user_two + self.expire(user_two, admin_token) + + # Do nothing to user_three + + return user_one, user_two, user_three, user_three_token + + def expire(self, user_id_to_expire, admin_tok): + url = "/_matrix/client/unstable/admin/account_validity/validity" + request_data = { + "user_id": user_id_to_expire, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request, channel = self.make_request( + "POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def deactivate(self, user_id, tok): + request_data = { + "auth": {"type": "m.login.password", "user": user_id, "password": "pass"}, + "erase": False, + } + request, channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.render(request) + self.assertEqual(request.code, 200) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..c3f7a28dcc 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -101,7 +101,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=self.tls_factory, + tls_client_options_factory=FederationPolicyForHTTPS(config), user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index e69de29bb2..a58d51441c 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py
@@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + + +class LoggerCleanupMixin: + def get_logger(self, handler): + """ + Attach a handler to a logger and add clean-ups to remove revert this. + """ + # Create a logger and add the handler to it. + logger = logging.getLogger(__name__) + logger.addHandler(handler) + + # Ensure the logger actually logs something. + logger.setLevel(logging.INFO) + + # Ensure the logger gets cleaned-up appropriately. + self.addCleanup(logger.removeHandler, handler) + self.addCleanup(logger.setLevel, logging.NOTSET) + + return logger diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py new file mode 100644
index 0000000000..4bc27a1d7d --- /dev/null +++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import AccumulatingProtocol + +from synapse.logging import RemoteHandler + +from tests.logging import LoggerCleanupMixin +from tests.server import FakeTransport, get_clock +from tests.unittest import TestCase + + +def connect_logging_client(reactor, client_id): + # This is essentially tests.server.connect_client, but disabling autoflush on + # the client transport. This is necessary to avoid an infinite loop due to + # sending of data via the logging transport causing additional logs to be + # written. + factory = reactor.tcpClients.pop(client_id)[2] + client = factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, reactor)) + client.makeConnection(FakeTransport(server, reactor, autoflush=False)) + + return client, server + + +class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): + def setUp(self): + self.reactor, _ = get_clock() + + def test_log_output(self): + """ + The remote handler delivers logs over TCP. + """ + handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor) + logger = self.get_logger(handler) + + logger.info("Hello there, %s!", "wally") + + # Trigger the connection + client, server = connect_logging_client(self.reactor, 0) + + # Trigger data being sent + client.transport.flush() + + # One log message, with a single trailing newline + logs = server.data.decode("utf8").splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(server.data.count(b"\n"), 1) + + # Ensure the data passed through properly. + self.assertEqual(logs[0], "Hello there, wally!") + + def test_log_backpressure_debug(self): + """ + When backpressure is hit, DEBUG logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 7): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # Only the 7 infos made it through, the debugs were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 7) + self.assertNotIn(b"debug", server.data) + + def test_log_backpressure_info(self): + """ + When backpressure is hit, DEBUG and INFO logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 10): + logger.warning("warn %s" % (i,)) + + # Send a bunch of info messages + for i in range(0, 3): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The 10 warnings made it through, the debugs and infos were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 10) + self.assertNotIn(b"debug", server.data) + self.assertNotIn(b"info", server.data) + + def test_log_backpressure_cut_middle(self): + """ + When backpressure is hit, and no more DEBUG and INFOs cannot be culled, + it will cut the middle messages out. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a bunch of useful messages + for i in range(0, 20): + logger.warning("warn %s" % (i,)) + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The first five and last five warnings made it through, the debugs and + # infos were elided + logs = server.data.decode("utf8").splitlines() + self.assertEqual( + ["warn %s" % (i,) for i in range(5)] + + ["warn %s" % (i,) for i in range(15, 20)], + logs, + ) + + def test_cancel_connection(self): + """ + Gracefully handle the connection being cancelled. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a message. + logger.info("Hello there, %s!", "wally") + + # Do not accept the connection and shutdown. This causes the pending + # connection to be cancelled (and should not raise any exceptions). + handler.close() diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py deleted file mode 100644
index d36f5f426c..0000000000 --- a/tests/logging/test_structured.py +++ /dev/null
@@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import os.path -import shutil -import sys -import textwrap - -from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile - -from synapse.config.logger import setup_logging -from synapse.logging._structured import setup_structured_logging -from synapse.logging.context import LoggingContext - -from tests.unittest import DEBUG, HomeserverTestCase - - -class FakeBeginner: - def beginLoggingTo(self, observers, **kwargs): - self.observers = observers - - -class StructuredLoggingTestBase: - """ - Test base that registers a cleanup handler to reset the stdlib log handler - to 'unset'. - """ - - def prepare(self, reactor, clock, hs): - def _cleanup(): - logging.getLogger("synapse").setLevel(logging.NOTSET) - - self.addCleanup(_cleanup) - - -class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase): - """ - Tests for Synapse's structured logging support. - """ - - def test_output_to_json_round_trip(self): - """ - Synapse logs can be outputted to JSON and then read back again. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json")) - - log_config = { - "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}} - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(json_log_file, "r") as f: - logged_events = list(eventsFromJSONLogFile(f)) - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertEqual( - eventAsText(logged_events[0], includeTimestamp=False), - "[tests.logging.test_structured#info] Hello there, wally!", - ) - - def test_output_to_text(self): - """ - Synapse logs can be outputted to text. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - log_file = os.path.abspath(os.path.join(temp_dir, "out.log")) - - log_config = {"drains": {"file": {"type": "file", "location": log_file}}} - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(log_file, "r") as f: - logged_events = f.read().strip().split("\n") - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertTrue( - logged_events[0].endswith( - " - tests.logging.test_structured - INFO - None - Hello there, wally!" - ) - ) - - def test_collects_logcontext(self): - """ - Test that log outputs have the attached logging context. - """ - log_config = {"drains": {}} - - # Begin the logger with our config - beginner = FakeBeginner() - publisher = setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logs = [] - - publisher.addObserver(logs.append) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - self.assertEqual(len(logs), 1) - self.assertEqual(logs[0]["request"], "somereq") - - -class StructuredLoggingConfigurationFileTestCase( - StructuredLoggingTestBase, HomeserverTestCase -): - def make_homeserver(self, reactor, clock): - - tempdir = self.mktemp() - os.mkdir(tempdir) - log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml")) - self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log")) - - config = self.default_config() - config["log_config"] = log_config_file - - with open(log_config_file, "w") as f: - f.write( - textwrap.dedent( - """\ - structured: true - - drains: - file: - type: file_json - location: %s - """ - % (self.homeserver_log,) - ) - ) - - self.addCleanup(self._sys_cleanup) - - return self.setup_test_homeserver(config=config) - - def _sys_cleanup(self): - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # Do not remove! We need the logging system to be set other than WARNING. - @DEBUG - def test_log_output(self): - """ - When a structured logging config is given, Synapse will use it. - """ - beginner = FakeBeginner() - publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner) - - # Make a logger and send an event - logger = Logger(namespace="tests.logging.test_structured", observer=publisher) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - with open(self.homeserver_log, "r") as f: - logged_events = [ - eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f) - ] - - logs = "\n".join(logged_events) - self.assertTrue("***** STARTING SERVER *****" in logs) - self.assertTrue("Hello there, steve!" in logs) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index fd128b88e0..73f469b802 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -14,57 +14,33 @@ # limitations under the License. import json -from collections import Counter +import logging +from io import StringIO -from twisted.logger import Logger +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter -from synapse.logging._structured import setup_structured_logging +from tests.logging import LoggerCleanupMixin +from tests.unittest import TestCase -from tests.server import connect_client -from tests.unittest import HomeserverTestCase -from .test_structured import FakeBeginner, StructuredLoggingTestBase - - -class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): - def test_log_output(self): +class TerseJsonTestCase(LoggerCleanupMixin, TestCase): + def test_terse_json_output(self): """ - The Terse JSON outputter delivers simplified structured logs over TCP. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - } - } - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logger = Logger( - namespace="tests.logging.test_terse_json", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Trigger the connection - self.pump() + output = StringIO() - _, server = connect_client(self.reactor, 0) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Trigger data being sent - self.pump() + logger.info("Hello there, %s!", "wally") - # One log message, with a single trailing newline - logs = server.data.decode("utf8").splitlines() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() self.assertEqual(len(logs), 1) - self.assertEqual(server.data.count(b"\n"), 1) - + self.assertEqual(data.count("\n"), 1) log = json.loads(logs[0]) # The terse logger should give us these keys. @@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): "log", "time", "level", - "log_namespace", - "request", - "scope", - "server_name", - "name", + "namespace", ] self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") - # It contains the data we expect. - self.assertEqual(log["name"], "wally") - - def test_log_backpressure_debug(self): + def test_extra_data(self): """ - When backpressure is hit, DEBUG logs will be shed. + Additional information can be included in the structured logging. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + output = StringIO() - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 7): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - _, server = connect_client(self.reactor, 0) - self.pump() - - # Only the 7 infos made it through, the debugs were elided - logs = server.data.splitlines() - self.assertEqual(len(logs), 7) - - def test_log_backpressure_info(self): - """ - When backpressure is hit, DEBUG and INFO logs will be shed. - """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + logger.info( + "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True} ) - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) - - # Send a bunch of useful messages - for i in range(0, 10): - logger.warn("test warn %s" % (i,)) - - # Send a bunch of info messages - for i in range(0, 3): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The 10 warnings made it through, the debugs and infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "time", + "level", + "namespace", + # The additional keys given via extra. + "foo", + "int", + "bool", + ] + self.assertCountEqual(log.keys(), expected_log_keys) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + # Check the values of the extra fields. + self.assertEqual(log["foo"], "bar") + self.assertEqual(log["int"], 3) + self.assertIs(log["bool"], True) - def test_log_backpressure_cut_middle(self): + def test_json_output(self): """ - When backpressure is hit, and no more DEBUG and INFOs cannot be culled, - it will cut the middle messages out. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + output = StringIO() - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + handler = logging.StreamHandler(output) + handler.setFormatter(JsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 20): - logger.warn("test warn", num=i) + logger.info("Hello there, %s!", "wally") - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The first five and last five warnings made it through, the debugs and - # infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) - self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs]) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "level", + "namespace", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 55545d9341..bcdcafa5a9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( @@ -131,6 +131,35 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_invite_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # We should get emailed about the invite + self._check_for_mail() + + def test_invite_to_empty_room_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # Then have the original user leave + self.helper.leave(room, self.others[0].id, tok=self.others[0].token) + + # We should get emailed about the invite + self._check_for_mail() + def test_multiple_members_email(self): # We want to test multiple notifications, so we pause processing of push # while we send messages. diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..826cebbf0c 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -346,8 +346,8 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_mention(self): """ @@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -418,8 +418,8 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_atroom(self): """ @@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -497,5 +497,5 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 093e2faac7..5c633ac6df 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -16,7 +16,6 @@ import logging from typing import Any, Callable, List, Optional, Tuple import attr -import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -39,12 +38,22 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeTransport, render +try: + import hiredis +except ImportError: + hiredis = None + logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + # hiredis is an optional dependency so we don't want to require it for running + # the tests. + if not hiredis: + skip = "Requires hiredis" + servlets = [ streams.register_servlets, ] @@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, - **kwargs + **kwargs, ) # If the instance is in the `instance_map` config then workers may try diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase): sender=sender, type="test_event", content={"body": body}, - **kwargs + **kwargs, ) ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..67c27a089f 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_dict["token_id"] + token_id = user_dict.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 92c9058887..d89eb90cfe 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -393,6 +393,22 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_user_has_no_devices(self): + """ + Tests that a normal lookup for devices is successfully + if user has no devices + """ + + # Get devices + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["devices"])) + def test_get_devices(self): """ Tests that a normal lookup for devices is successfully @@ -409,6 +425,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) # Check that all fields are available diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index bf79086f78..303622217f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -70,6 +70,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" + def test_no_auth(self): + """ + Try to get an event report without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error 403 is returned. @@ -266,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_limit_is_negative(self): """ - Testing that a negative list parameter returns a 400 + Testing that a negative limit parameter returns a 400 """ request, channel = self.make_request( @@ -360,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in a event report + """Checks that all attributes are present in an event report """ for c in content: self.assertIn("id", c) @@ -368,15 +378,175 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", c) self.assertIn("event_id", c) self.assertIn("user_id", c) - self.assertIn("reason", c) - self.assertIn("content", c) self.assertIn("sender", c) - self.assertIn("room_alias", c) - self.assertIn("event_json", c) - self.assertIn("score", c["content"]) - self.assertIn("reason", c["content"]) - self.assertIn("auth_events", c["event_json"]) - self.assertIn("type", c["event_json"]) - self.assertIn("room_id", c["event_json"]) - self.assertIn("sender", c["event_json"]) - self.assertIn("content", c["event_json"]) + self.assertIn("canonical_alias", c) + self.assertIn("name", c) + self.assertIn("score", c) + self.assertIn("reason", c) + + +class EventReportDetailTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + + # first created event report gets `id`=2 + self.url = "/_synapse/admin/v1/event_reports/2" + + def test_no_auth(self): + """ + Try to get event report without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing get a reported event + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._check_fields(channel.json_body) + + def test_invalid_report_id(self): + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self): + """ + Testing that a not existing `report_id` returns a 404. + """ + + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + request, channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + self.assertIn("id", content) + self.assertIn("received_ts", content) + self.assertIn("room_id", content) + self.assertIn("event_id", content) + self.assertIn("user_id", content) + self.assertIn("sender", content) + self.assertIn("canonical_alias", content) + self.assertIn("name", content) + self.assertIn("event_json", content) + self.assertIn("score", content) + self.assertIn("reason", content) + self.assertIn("auth_events", content["event_json"]) + self.assertIn("type", content["event_json"]) + self.assertIn("room_id", content["event_json"]) + self.assertIn("sender", content["event_json"]) + self.assertIn("content", content["event_json"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py new file mode 100644
index 0000000000..721fa1ed51 --- /dev/null +++ b/tests/rest/admin/test_media.py
@@ -0,0 +1,568 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from binascii import unhexlify + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, profile, room +from synapse.rest.media.v1.filepath import MediaFilePaths + +from tests import unittest + + +class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request("DELETE", url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_does_not_exist(self): + """ + Tests that a lookup for a media that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for a media that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_delete_media(self): + """ + Tests that delete a media is successfully + """ + + download_resource = self.media_repo.children[b"download"] + upload_resource = self.media_repo.children[b"upload"] + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name, media_id = server_and_media_id.split("/") + + self.assertEqual(server_name, self.server_name) + + # Attempt to access media + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + + # Should be successful + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" % server_and_media_id + ), + ) + + # Test if the file exists + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) + + # Delete media + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + # Attempt to access media + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % server_and_media_id + ), + ) + + # Test if the file is deleted + self.assertFalse(os.path.exists(local_path)) + + +class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + self.clock = hs.clock + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "POST", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for media that is not local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" + + request, channel = self.make_request( + "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_missing_parameter(self): + """ + If the parameter `before_ts` is missing, an error is returned. + """ + request, channel = self.make_request( + "POST", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Missing integer query parameter b'before_ts'", channel.json_body["error"] + ) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + request, channel = self.make_request( + "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts must be a string representing a positive integer.", + channel.json_body["error"], + ) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=1234&size_gt=-1234", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter size_gt must be a string representing a positive integer.", + channel.json_body["error"], + ) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=1234&keep_profiles=not_bool", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + channel.json_body["error"], + ) + + def test_delete_media_never_accessed(self): + """ + Tests that media deleted if it is older than `before_ts` and never accessed + `last_access_ts` is `NULL` and `created_ts` < `before_ts` + """ + + # upload and do not access + server_and_media_id = self._create_media() + self.pump(1.0) + + # test that the file exists + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + # timestamp after upload/create + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_date(self): + """ + Tests that media is not deleted if it is newer than `before_ts` + """ + + # timestamp before upload + now_ms = self.clock.time_msec() + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + # timestamp after upload + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_size(self): + """ + Tests that media is not deleted if its size is smaller than or equal + to `size_gt` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_user_avatar(self): + """ + Tests that we do not delete media if is used as a user avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as avatar + request, channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.admin_user,), + content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_room_avatar(self): + """ + Tests that we do not delete media if it is used as a room avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as room avatar + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + request, channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.avatar" % (room_id,), + content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def _create_media(self): + """ + Create a media and return media_id and server_and_media_id + """ + upload_resource = self.media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name = server_and_media_id.split("/")[0] + + # Check that new media is a local and not remote + self.assertEqual(server_name, self.server_name) + + return server_and_media_id + + def _access_media(self, server_and_media_id, expect_success=True): + """ + Try to access a media and check the result + """ + download_resource = self.media_repo.children[b"download"] + + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + + if expect_success: + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" + % server_and_media_id + ), + ) + # Test that the file exists + self.assertTrue(os.path.exists(local_path)) + else: + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % (server_and_media_id) + ), + ) + # Test that the file is deleted + self.assertFalse(os.path.exists(local_path)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 98d0623734..7df32e5093 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -17,6 +17,7 @@ import hashlib import hmac import json import urllib.parse +from binascii import unhexlify from mock import Mock @@ -1016,7 +1017,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, - sync.register_servlets, room.register_servlets, ] @@ -1082,6 +1082,21 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_no_memberships(self): + """ + Tests that a normal lookup for rooms is successfully + if user has no memberships + """ + # Get rooms + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) + def test_get_rooms(self): """ Tests that a normal lookup for rooms is successfully @@ -1101,3 +1116,408 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + + +class PushersRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list pushers of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_pushers(self): + """ + Tests that a normal lookup for pushers is successfully + """ + + # Get pushers + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + # Register the pusher + other_user_token = self.login("user", "pass") + user_tuple = self.get_success( + self.store.get_user_by_access_token(other_user_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=self.other_user, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Get pushers + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + + for p in channel.json_body["pushers"]: + self.assertIn("pushkey", p) + self.assertIn("kind", p) + self.assertIn("app_id", p) + self.assertIn("app_display_name", p) + self.assertIn("device_display_name", p) + self.assertIn("profile_tag", p) + self.assertIn("lang", p) + self.assertIn("url", p["data"]) + + +class UserMediaRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.media_repo = hs.get_media_repository_resource() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/media" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list media of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/media" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_limit(self): + """ + Testing list of media with limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["media"]) + + def test_from(self): + """ + Testing list of media with a defined starting point (from) + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def test_limit_and_from(self): + """ + Testing list of media with a defined starting point and limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["media"]), 10) + self._check_fields(channel.json_body["media"]) + + def test_limit_is_negative(self): + """ + Testing that a negative limit parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + # `next_token` does not appear + # Number of results is the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + request, channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_user_has_no_media(self): + """ + Tests that a normal lookup for media is successfully + if user has no media created + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["media"])) + + def test_get_media(self): + """ + Tests that a normal lookup for media is successfully + """ + + number_media = 5 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["media"])) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def _create_media(self, user_token, number_media): + """ + Create a number of media for a specific user + """ + upload_resource = self.media_repo.children[b"upload"] + for i in range(number_media): + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + self.helper.upload_media( + upload_resource, image_data, tok=user_token, expect_code=200 + ) + + def _check_fields(self, content): + """Checks that all attributes are present in content + """ + for m in content: + self.assertIn("media_id", m) + self.assertIn("media_type", m) + self.assertIn("media_length", m) + self.assertIn("upload_name", m) + self.assertIn("created_ts", m) + self.assertIn("last_access_ts", m) + self.assertIn("quarantined_by", m) + self.assertIn("safe_from_quarantine", m) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..3d81618f19 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@ import json +from mock import Mock + +from twisted.internet import defer + import synapse.rest.admin from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import account from tests import unittest -class IdentityTestCase(unittest.HomeserverTestCase): +class IdentityDisabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts fail when the HS's config disallows them.""" servlets = [ + account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, @@ -32,24 +39,111 @@ class IdentityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["trusted_third_party_id_servers"] = ["testis"] config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_disabled(self): + request, channel = self.make_request( + b"POST", "/createRoom", b"{}", access_token=self.tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + room_id = channel.json_body["room_id"] + + params = { + "id_server": "testis", + "medium": "email", + "address": "test@example.com", + } + request_data = json.dumps(params) + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") + request, channel = self.make_request( + b"POST", request_url, request_data, access_token=self.tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) + request, channel = self.make_request("GET", url, access_token=self.tok) + self.render(request) + self.assertEqual(channel.result["code"], b"403", channel.result) + + def test_3pid_bulk_lookup_disabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok + ) + self.render(request) + self.assertEqual(channel.result["code"], b"403", channel.result) + + +class IdentityEnabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts succeed when the HS's config allows them.""" + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["enable_3pid_lookup"] = True + config["trusted_third_party_id_servers"] = ["testis"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.get_json.return_value = defer.succeed((200, "{}")) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_identity_handler().http_client = mock_http_client + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_enabled(self): request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok + b"POST", "/createRoom", b"{}", access_token=self.tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) room_id = channel.json_body["room_id"] + # Replace the blacklisting SimpleHttpClient with our mock + self.hs.get_room_member_handler().simple_http_client = Mock( + spec=["get_json", "post_json_get_json"] + ) + self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed( + (200, "{}") + ) + params = { "id_server": "testis", "medium": "email", @@ -58,7 +152,44 @@ class IdentityTestCase(unittest.HomeserverTestCase): request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok + b"POST", request_url, request_data, access_token=self.tok ) self.render(request) - self.assertEquals(channel.result["code"], b"403", channel.result) + + get_json = self.hs.get_identity_handler().http_client.get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + {"address": "test@example.com", "medium": "email"}, + ) + + def test_3pid_lookup_enabled(self): + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) + request, channel = self.make_request("GET", url, access_token=self.tok) + self.render(request) + + get_json = self.hs.get_simple_http_client().get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + {"address": "foo@bar.baz", "medium": "email"}, + ) + + def test_3pid_bulk_lookup_enabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok + ) + self.render(request) + + post_json = self.hs.get_simple_http_client().post_json_get_json + post_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/bulk_lookup", + {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]}, + ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7d3773ff78..47c0d5634c 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["default_room_version"] = "1" config["retention"] = { "enabled": True, "default_policy": { @@ -243,6 +244,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["default_room_version"] = "1" config["retention"] = { "enabled": True, } diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py new file mode 100644
index 0000000000..1a4ea34292 --- /dev/null +++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1083 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import random +import string +from typing import Optional + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.api.errors import SynapseError +from synapse.rest import admin +from synapse.rest.client.v1 import directory, login, room +from synapse.third_party_rules.access_rules import ( + ACCESS_RULES_TYPE, + AccessRules, + RoomAccessRules, +) +from synapse.types import JsonDict, create_requester + +from tests import unittest + + +class RoomAccessTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["third_party_event_rules"] = { + "module": "synapse.third_party_rules.access_rules.RoomAccessRules", + "config": { + "domains_forbidden_when_restricted": ["forbidden_domain"], + "id_server": "testis", + }, + } + config["trusted_third_party_id_servers"] = ["testis"] + + def send_invite(destination, room_id, event_id, pdu): + return defer.succeed(pdu) + + def get_json(uri, args={}, headers=None): + address_domain = args["address"].split("@")[1] + return defer.succeed({"hs": address_domain}) + + def post_json_get_json(uri, post_json, args={}, headers=None): + token = "".join(random.choice(string.ascii_letters) for _ in range(10)) + return defer.succeed( + { + "token": token, + "public_keys": [ + { + "public_key": "serverpublickey", + "key_validity_url": "https://testis/pubkey/isvalid", + }, + { + "public_key": "phemeralpublickey", + "key_validity_url": "https://testis/pubkey/ephemeral/isvalid", + }, + ], + "display_name": "f...@b...", + } + ) + + mock_federation_client = Mock(spec=["send_invite"]) + mock_federation_client.send_invite.side_effect = send_invite + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"],) + # Mocking the response for /info on the IS API. + mock_http_client.get_json.side_effect = get_json + # Mocking the response for /store-invite on the IS API. + mock_http_client.post_json_get_json.side_effect = post_json_get_json + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + simple_http_client=mock_http_client, + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_identity_handler().blacklisting_http_client = mock_http_client + + self.third_party_event_rules = self.hs.get_third_party_event_rules() + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.restricted_room = self.create_room() + self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED) + self.direct_rooms = [ + self.create_room(direct=True), + self.create_room(direct=True), + self.create_room(direct=True), + ] + + self.invitee_id = self.register_user("invitee", "test") + self.invitee_tok = self.login("invitee", "test") + + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + ) + + def test_create_room_no_rule(self): + """Tests that creating a room with no rule will set the default.""" + room_id = self.create_room() + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.RESTRICTED) + + def test_create_room_direct_no_rule(self): + """Tests that creating a direct room with no rule will set the default.""" + room_id = self.create_room(direct=True) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.DIRECT) + + def test_create_room_valid_rule(self): + """Tests that creating a room with a valid rule will set the right.""" + room_id = self.create_room(rule=AccessRules.UNRESTRICTED) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.UNRESTRICTED) + + def test_create_room_invalid_rule(self): + """Tests that creating a room with an invalid rule will set fail.""" + self.create_room(rule=AccessRules.DIRECT, expected_code=400) + + def test_create_room_direct_invalid_rule(self): + """Tests that creating a direct room with an invalid rule will fail. + """ + self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400) + + def test_create_room_default_power_level_rules(self): + """Tests that a room created with no power level overrides instead uses the dinum + defaults + """ + room_id = self.create_room(direct=True, rule=AccessRules.DIRECT) + power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok) + + # Inviting another user should require PL50, even in private rooms + self.assertEqual(power_levels["invite"], 50) + # Sending arbitrary state events should require PL100 + self.assertEqual(power_levels["state_default"], 100) + + def test_create_room_fails_on_incorrect_power_level_rules(self): + """Tests that a room created with power levels lower than that required are rejected""" + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + modified_power_levels["invite"] = 0 + modified_power_levels["state_default"] = 50 + + self.create_room( + direct=True, + rule=AccessRules.DIRECT, + initial_state=[ + {"type": "m.room.power_levels", "content": modified_power_levels} + ], + expected_code=400, + ) + + def test_create_room_with_missing_power_levels_use_default_values(self): + """ + Tests that a room created with custom power levels, but without defining invite or state_default + succeeds, but the missing values are replaced with the defaults. + """ + + # Attempt to create a room without defining "invite" or "state_default" + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + del modified_power_levels["invite"] + del modified_power_levels["state_default"] + room_id = self.create_room( + direct=True, + rule=AccessRules.DIRECT, + initial_state=[ + {"type": "m.room.power_levels", "content": modified_power_levels} + ], + ) + + # This should succeed, but the defaults should be put in place instead + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + self.assertEqual(room_power_levels["invite"], 50) + self.assertEqual(room_power_levels["state_default"], 100) + + # And now the same test, but using power_levels_content_override instead + # of initial_state (which takes a slightly different codepath) + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + del modified_power_levels["invite"] + del modified_power_levels["state_default"] + room_id = self.create_room( + direct=True, + rule=AccessRules.DIRECT, + power_levels_content_override=modified_power_levels, + ) + + # This should succeed, but the defaults should be put in place instead + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + self.assertEqual(room_power_levels["invite"], 50) + self.assertEqual(room_power_levels["state_default"], 100) + + def test_existing_room_can_change_power_levels(self): + """Tests that a room created with default power levels can have their power levels + dropped after room creation + """ + # Creates a room with the default power levels + room_id = self.create_room( + direct=True, rule=AccessRules.DIRECT, expected_code=200, + ) + + # Attempt to drop invite and state_default power levels after the fact + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + room_power_levels["invite"] = 0 + room_power_levels["state_default"] = 50 + self.helper.send_state( + room_id, "m.room.power_levels", room_power_levels, self.tok + ) + + def test_public_room(self): + """Tests that it's only possible to have a room listed in the public room list + if the access rule is restricted. + """ + # Creating a room with the public_chat preset should succeed and set the access + # rule to restricted. + preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT) + self.assertEqual( + self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED + ) + + # Creating a room with the public join rule in its initial state should succeed + # and set the access rule to restricted. + init_state_room_id = self.create_room( + initial_state=[ + { + "type": "m.room.join_rules", + "content": {"join_rule": JoinRules.PUBLIC}, + } + ] + ) + self.assertEqual( + self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED + ) + + # List preset_room_id in the public room list + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,), + {"visibility": "public"}, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # List init_state_room_id in the public room list + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,), + {"visibility": "public"}, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # Changing access rule to unrestricted should fail. + self.change_rule_in_room( + preset_room_id, AccessRules.UNRESTRICTED, expected_code=403 + ) + self.change_rule_in_room( + init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403 + ) + + # Changing access rule to direct should fail. + self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403) + self.change_rule_in_room( + init_state_room_id, AccessRules.DIRECT, expected_code=403 + ) + + # Creating a new room with the public_chat preset and an access rule of direct + # should fail. + self.create_room( + preset=RoomCreationPreset.PUBLIC_CHAT, + rule=AccessRules.DIRECT, + expected_code=400, + ) + + # Changing join rule to public in an direct room should fail. + self.change_join_rule_in_room( + self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403 + ) + + def test_restricted(self): + """Tests that in restricted mode we're unable to invite users from blacklisted + servers but can invite other users. + + Also tests that the room can be published to, and removed from, the public room + list. + """ + # We can't invite a user from a forbidden HS. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=403, + ) + + # We can invite a user which HS isn't forbidden. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:allowed_domain", + tok=self.tok, + expect_code=200, + ) + + # We can't send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.restricted_room, + expected_code=403, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.restricted_room, + expected_code=200, + ) + + # We are allowed to publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # We are allowed to remove the room from the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room + data = {"visibility": "private"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + def test_direct(self): + """Tests that, in direct mode, other users than the initial two can't be invited, + but the following scenario works: + * invited user joins the room + * invited user leaves the room + * room creator re-invites invited user + + Tests that a user from a HS that's in the list of forbidden domains (to use + in restricted mode) can be invited. + + Tests that the room cannot be published to the public room list. + """ + not_invited_user = "@not_invited:forbidden_domain" + + # We can't invite a new user to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=403, + ) + + # The invited user can join the room. + self.helper.join( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can leave the room. + self.helper.leave( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can be re-invited to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + expect_code=200, + ) + + # If we're alone in the room and have always been the only member, we can invite + # someone. + self.helper.invite( + room=self.direct_rooms[1], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=200, + ) + + # Disable the 3pid invite ratelimiter + burst = self.hs.config.rc_third_party_invite.burst_count + per_second = self.hs.config.rc_third_party_invite.per_second + self.hs.config.rc_third_party_invite.burst_count = 10 + self.hs.config.rc_third_party_invite.per_second = 0.1 + + # We can't send a 3PID invite to a room that already has two members. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[0], + expected_code=403, + ) + + # We can't send a 3PID invite to a room that already has a pending invite. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[1], + expected_code=403, + ) + + # We can send a 3PID invite to a room in which we've always been the only member. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=200, + ) + + # We can send a 3PID invite to a room in which there's a 3PID invite. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=403, + ) + + self.hs.config.rc_third_party_invite.burst_count = burst + self.hs.config.rc_third_party_invite.per_second = per_second + + # We can't publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0] + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 403, channel.result) + + def test_unrestricted(self): + """Tests that, in unrestricted mode, we can invite whoever we want, but we can + only change the power level of users that wouldn't be forbidden in restricted + mode. + + Tests that the room cannot be published to the public room list. + """ + # We can invite + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:not_forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a power level event that doesn't redefine the default PL or set a + # non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}}, + tok=self.tok, + expect_code=200, + ) + + # We can't send a power level event that redefines the default PL and doesn't set + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={ + "users": {self.user_id: 100, "@test:not_forbidden_domain": 10}, + "users_default": 10, + }, + tok=self.tok, + expect_code=403, + ) + + # We can't send a power level event that doesn't redefines the default PL but sets + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}}, + tok=self.tok, + expect_code=403, + ) + + # We can't publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 403, channel.result) + + def test_change_rules(self): + """Tests that we can only change the current rule from restricted to + unrestricted. + """ + # We can't change the rule from restricted to direct. + self.change_rule_in_room( + room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403 + ) + + # We can change the rule from restricted to unrestricted. + # Note that this changes self.restricted_room to an unrestricted room + self.change_rule_in_room( + room_id=self.restricted_room, + new_rule=AccessRules.UNRESTRICTED, + expected_code=200, + ) + + # We can't change the rule from unrestricted to restricted. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=AccessRules.RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from unrestricted to direct. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=AccessRules.DIRECT, + expected_code=403, + ) + + # We can't change the rule from direct to restricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=AccessRules.RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from direct to unrestricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=AccessRules.UNRESTRICTED, + expected_code=403, + ) + + # We can't publish a room to the public room list and then change its rule to + # unrestricted + + # Create a restricted room + test_room_id = self.create_room(rule=AccessRules.RESTRICTED) + + # Publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # Attempt to switch the room to "unrestricted" + self.change_rule_in_room( + room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403 + ) + + # Attempt to switch the room to "direct" + self.change_rule_in_room( + room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403 + ) + + def test_change_room_avatar(self): + """Tests that changing the room avatar is always allowed unless the room is a + direct chat, in which case it's forbidden. + """ + + avatar_content = { + "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394}, + "url": "mxc://example.org/JWEIFJgwEIhweiWJE", + } + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=403, + ) + + def test_change_room_name(self): + """Tests that changing the room name is always allowed unless the room is a direct + chat, in which case it's forbidden. + """ + + name_content = {"name": "My super room"} + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=403, + ) + + def test_change_room_topic(self): + """Tests that changing the room topic is always allowed unless the room is a + direct chat, in which case it's forbidden. + """ + + topic_content = {"topic": "Welcome to this room"} + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=403, + ) + + def test_revoke_3pid_invite_direct(self): + """Tests that revoking a 3PID invite doesn't cause the room access rules module to + confuse the revokation as a new 3PID invite. + """ + invite_token = "sometoken" + + invite_body = { + "display_name": "ker...@exa...", + "public_keys": [ + { + "key_validity_url": "https://validity_url", + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", + }, + { + "key_validity_url": "https://validity_url", + "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I", + }, + ], + "key_validity_url": "https://validity_url", + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", + } + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body=invite_body, + tok=self.tok, + ) + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body={}, + tok=self.tok, + ) + + invite_token = "someothertoken" + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body=invite_body, + tok=self.tok, + ) + + def test_check_event_allowed(self): + """Tests that RoomAccessRules.check_event_allowed behaves accordingly. + + It tests that: + * forbidden users cannot join restricted rooms. + * forbidden users can only join unrestricted rooms if they have an invite. + """ + event_creator = self.hs.get_event_creation_handler() + + # Test that forbidden users cannot join restricted rooms + requester = create_requester(self.user_id) + allowed_requester = create_requester("@user:allowed_domain") + forbidden_requester = create_requester("@user:forbidden_domain") + + # Assert a join event from a forbidden user to a restricted room is rejected + self.get_failure( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.restricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ), + SynapseError, + ) + + # A join event from an non-forbidden user to a restricted room is allowed + self.get_success( + event_creator.create_event( + allowed_requester, + { + "type": EventTypes.Member, + "room_id": self.restricted_room, + "sender": allowed_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": allowed_requester.user.to_string(), + }, + ) + ) + + # Test that forbidden users can only join unrestricted rooms if they have an invite + + # A forbidden user without an invite should not be able to join an unrestricted room + self.get_failure( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.unrestricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ), + SynapseError, + ) + + # However, if we then invite this user... + self.helper.invite( + room=self.unrestricted_room, + src=requester.user.to_string(), + targ=forbidden_requester.user.to_string(), + tok=self.tok, + ) + + # And create another join event, making sure that its context states it's coming + # in after the above invite was made... + # Then the forbidden user should be able to join! + self.get_success( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.unrestricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ) + ) + + def test_freezing_a_room(self): + """Tests that the power levels in a room change to prevent new events from + non-admin users when the last admin of a room leaves. + """ + + def freeze_room_with_id_and_power_levels( + room_id: str, custom_power_levels_content: Optional[JsonDict] = None, + ): + # Invite a user to the room, they join with PL 0 + self.helper.invite( + room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok, + ) + + # Invitee joins the room + self.helper.join( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + if not custom_power_levels_content: + # Retrieve the room's current power levels event content + power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.tok, + ) + else: + power_levels = custom_power_levels_content + + # Override the room's power levels with the given power levels content + self.helper.send_state( + room_id=room_id, + event_type="m.room.power_levels", + body=custom_power_levels_content, + tok=self.tok, + ) + + # Ensure that the invitee leaving the room does not change the power levels + self.helper.leave( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + # Retrieve the new power levels of the room + new_power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.tok, + ) + + # Ensure they have not changed + self.assertDictEqual(power_levels, new_power_levels) + + # Invite the user back again + self.helper.invite( + room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok, + ) + + # Invitee joins the room + self.helper.join( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + # Now the admin leaves the room + self.helper.leave( + room=room_id, user=self.user_id, tok=self.tok, + ) + + # Check the power levels again + new_power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok, + ) + + # Ensure that the new power levels prevent anyone but admins from sending + # certain events + self.assertEquals(new_power_levels["state_default"], 100) + self.assertEquals(new_power_levels["events_default"], 100) + self.assertEquals(new_power_levels["kick"], 100) + self.assertEquals(new_power_levels["invite"], 100) + self.assertEquals(new_power_levels["ban"], 100) + self.assertEquals(new_power_levels["redact"], 100) + self.assertDictEqual(new_power_levels["events"], {}) + self.assertDictEqual(new_power_levels["users"], {self.user_id: 100}) + + # Ensure new users entering the room aren't going to immediately become admins + self.assertEquals(new_power_levels["users_default"], 0) + + # Test that freezing a room with the default power level state event content works + room1 = self.create_room() + freeze_room_with_id_and_power_levels(room1) + + # Test that freezing a room with a power level state event that is missing + # `state_default` and `event_default` keys behaves as expected + room2 = self.create_room() + freeze_room_with_id_and_power_levels( + room2, + { + "ban": 50, + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + }, + "invite": 0, + "kick": 50, + "redact": 50, + "users": {self.user_id: 100}, + "users_default": 0, + # Explicitly remove `state_default` and `event_default` keys + }, + ) + + # Test that freezing a room with a power level state event that is *additionally* + # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected + room3 = self.create_room() + freeze_room_with_id_and_power_levels( + room3, + { + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + }, + "users": {self.user_id: 100}, + "users_default": 0, + # Explicitly remove `state_default` and `event_default` keys + # Explicitly remove `ban`, `invite`, `kick` and `redact` keys + }, + ) + + def create_room( + self, + direct=False, + rule=None, + preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + initial_state=None, + power_levels_content_override=None, + expected_code=200, + ): + content = {"is_direct": direct, "preset": preset} + + if rule: + content["initial_state"] = [ + {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}} + ] + + if initial_state: + if "initial_state" not in content: + content["initial_state"] = [] + + content["initial_state"] += initial_state + + if power_levels_content_override: + content["power_levels_content_override"] = power_levels_content_override + + request, channel = self.make_request( + "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + if expected_code == 200: + return channel.json_body["room_id"] + + def current_rule_in_room(self, room_id): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["rule"] + + def change_rule_in_room(self, room_id, new_rule, expected_code=200): + data = {"rule": new_rule} + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + json.dumps(data), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200): + data = {"join_rule": new_join_rule} + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules), + json.dumps(data), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + def send_threepid_invite(self, address, room_id, expected_code=200): + params = {"id_server": "testis", "medium": "email", "address": address} + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/invite" % room_id, + json.dumps(params), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def send_state_with_state_key( + self, room_id, event_type, state_key, body, tok, expect_code=200 + ): + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) + + request, channel = self.make_request( + "PUT", path, json.dumps(body), access_token=tok + ) + self.render(request) + + self.assertEqual(channel.code, expect_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 5d987a30c7..eff17bdf61 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -518,8 +518,12 @@ class JWTTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs - def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( @@ -725,8 +729,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = "RS256" return self.hs - def jwt_encode(self, token, secret=jwt_privatekey): - return jwt.encode(token, secret, "RS256").decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..3a1f7c5da9 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime import json import os +from mock import Mock + import pkg_resources +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -55,6 +59,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", ) self.hs.get_datastore().services_cache.append(appservice) @@ -87,14 +92,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): - request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): user_id = "@kermit:test" device_id = "frogfone" @@ -303,6 +300,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(channel.json_body.get("sid")) +class RegisterHideProfileTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/register" + + config = self.default_config() + config["enable_registration"] = True + config["show_users_in_user_directory"] = False + config["replicate_user_profiles_to"] = ["fakeserver"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + return self.hs + + def test_profile_hidden(self): + user_id = self.register_user("kermit", "monkey") + + post_json = self.hs.get_simple_http_client().post_json_get_json + + # We expect post_json_get_json to have been called twice: once with the original + # profile and once with the None profile resulting from the request to hide it + # from the user directory. + self.assertEqual(post_json.call_count, 2, post_json.call_args_list) + + # Get the args (and not kwargs) passed to post_json. + args = post_json.call_args[0] + # Make sure the last call was attempting to replicate profiles. + split_uri = args[0].split("/") + self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0]) + # Make sure the last profile update was overriding the user's profile to None. + self.assertEqual(args[1]["batch"][user_id], None, args[1]) + + class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ @@ -312,6 +350,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): sync.register_servlets, logout.register_servlets, account_validity.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -437,6 +476,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) +class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.client.v1.profile.register_servlets, + synapse.rest.client.v1.room.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Set accounts to expire after a week + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + config["replicate_user_profiles_to"] = "test.is" + + # Mock homeserver requests to an identity server + mock_http_client = Mock(spec=["post_json_get_json"]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + return self.hs + + def test_expired_user_in_directory(self): + """Test that an expired user is hidden in the user directory""" + # Create an admin user to search the user directory + admin_id = self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + # Ensure the admin never expires + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": admin_id, + "expiration_ts": 999999999999, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Mock the homeserver's HTTP client + post_json = self.hs.get_simple_http_client().post_json_get_json + + # Create a user + username = "kermit" + user_id = self.register_user(username, "monkey") + self.login(username, "monkey") + self.get_success( + self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1) + ) + + # Check that a full profile for this user is replicated + self.assertIsNotNone(post_json.call_args, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + + self.assertIsNotNone(batch, batch) + self.assertEquals(len(batch), 1, batch) + + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's not None + replicated_content = batch[user_id] + self.assertIsNotNone(replicated_content) + + # Expire the user + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Wait for the background job to run which hides expired users in the directory + self.reactor.advance(60 * 60 * 1000) + + # Check if the homeserver has replicated the user's profile to the identity server + self.assertIsNotNone(post_json.call_args, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + + self.assertIsNotNone(batch, batch) + self.assertEquals(len(batch), 1, batch) + + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's None, signifying that the user should be removed from the user + # directory because they were expired + replicated_content = batch[user_id] + self.assertIsNone(replicated_content) + + # Now renew the user, and check they get replicated again to the identity server + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 99999999999, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.pump(10) + self.reactor.advance(10) + self.pump() + + # Check if the homeserver has replicated the user's profile to the identity server + post_json = self.hs.get_simple_http_client().post_json_get_json + self.assertNotEquals(post_json.call_args, None, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + self.assertNotEquals(batch, None, batch) + self.assertEquals(len(batch), 1, batch) + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's not None, signifying that the user is back in the user + # directory + replicated_content = batch[user_id] + self.assertIsNotNone(replicated_content) + + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): servlets = [ @@ -499,8 +687,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - # Move 6 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=6).total_seconds()) + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) self.assertEqual(len(self.email_attempts), 1) # Retrieving the URL from the email is too much pain for now, so we @@ -512,14 +700,33 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect on a successful renewal. - expected_html = self.hs.config.account_validity.account_renewed_html_content + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + request, channel = self.make_request(b"GET", url) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -541,15 +748,12 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect when using an # invalid/unknown token. - expected_html = self.hs.config.account_validity.invalid_token_html_content + expected_html = self.hs.config.account_validity_invalid_token_template.render() self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -587,7 +791,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(request.code, 200, channel.result) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) @@ -660,7 +864,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) - self.hs.config.account_validity.period = self.validity_period + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta self.store = self.hs.get_datastore() @@ -674,8 +883,6 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): """ user_id = self.register_user("kermit_delta", "user") - self.hs.config.account_validity.startup_job_max_delta = self.max_delta - now_ms = self.hs.clock.time_msec() self.get_success(self.store._set_expiration_date_when_missing()) diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py new file mode 100644
index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py new file mode 100644
index 0000000000..1accc70dc9 --- /dev/null +++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json + +import synapse.rest.admin +from synapse.config._base import ConfigError +from synapse.rest.client.v1 import login, room +from synapse.rulecheck.domain_rule_checker import DomainRuleChecker + +from tests import unittest +from tests.server import make_request, render + + +class DomainRuleCheckerTestCase(unittest.TestCase): + def test_allowed(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + "domains_prevented_from_being_invited_to_published_rooms": ["target_two"], + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_one", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_two", "test:target_two", None, "room", False + ) + ) + + # User can invite internal user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test1:target_one", None, "room", False, True + ) + ) + + # User can invite external user to a non-published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, False + ) + ) + + def test_disallowed(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + "source_four": [], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_one", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_one", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_four", "test:target_one", None, "room", False + ) + ) + + # User cannot invite external user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, True + ) + ) + + def test_default_allow(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_default_deny(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_config_parse(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + self.assertEquals(config, DomainRuleChecker.parse_config(config)) + + def test_config_parse_failure(self): + config = { + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + } + } + self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config) + + +class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + hijack_auth = False + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["trusted_third_party_id_servers"] = ["localhost"] + + config["spam_checker"] = { + "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker", + "config": { + "default": True, + "domain_mapping": {}, + "can_only_join_rooms_with_invite": True, + "can_only_create_one_to_one_rooms": True, + "can_only_invite_during_room_creation": True, + "can_invite_by_third_party_id": False, + }, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.admin_user_id = self.register_user("admin_user", "pass", admin=True) + self.admin_access_token = self.login("admin_user", "pass") + + self.normal_user_id = self.register_user("normal_user", "pass", admin=False) + self.normal_access_token = self.login("normal_user", "pass") + + self.other_user_id = self.register_user("other_user", "pass", admin=False) + + def test_admin_can_create_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + def test_normal_user_cannot_create_empty_room(self): + channel = self._create_room(self.normal_access_token) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_cannot_create_room_with_multiple_invites(self): + channel = self._create_room( + self.normal_access_token, + content={"invite": [self.other_user_id, self.admin_user_id]}, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly counts both normal and third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [self.other_user_id], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly rejects third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_can_room_with_single_invites(self): + channel = self._create_room( + self.normal_access_token, content={"invite": [self.other_user_id]} + ) + assert channel.result["code"] == b"200", channel.result + + def test_cannot_join_public_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403 + ) + + def test_can_join_invited_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + def test_cannot_invite(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + def test_cannot_3pid_invite(self): + """Test that unbound 3pid invites get rejected. + """ + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + request, channel = self.make_request( + "POST", + "rooms/%s/invite" % (room_id), + {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"}, + access_token=self.normal_access_token, + ) + self.render(request) + self.assertEqual(channel.code, 403, channel.result["body"]) + + def _create_room(self, token, content={}): + path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,) + + request, channel = make_request( + self.hs.get_reactor(), + "POST", + path, + content=json.dumps(content).encode("utf8"), + ) + render(request, self.resource, self.hs.get_reactor()) + + return channel diff --git a/tests/server.py b/tests/server.py
index 4d33b84097..b97003fa5a 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runWithConnection, func, *args, - **kwargs + **kwargs, ) def runInteraction(interaction, *args, **kwargs): @@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runInteraction, interaction, *args, - **kwargs + **kwargs, ) pool.runWithConnection = runWithConnection @@ -571,12 +571,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol reactor factory: The connecting factory to build. """ - factory = reactor.tcpClients[client_id][2] + factory = reactor.tcpClients.pop(client_id)[2] client = factory.buildProtocol(None) server = AccumulatingProtocol() server.makeConnection(FakeTransport(client, reactor)) client.makeConnection(FakeTransport(server, reactor)) - reactor.tcpClients.pop(client_id) - return client, server diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..5a1e5c4e66 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..e96ca1c8ca 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, - **make_request_args + **make_request_args, ) request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, generate_latest -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..fe37d2ed5a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py
@@ -39,7 +39,7 @@ class DataStoreTestCase(unittest.TestCase): ) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield defer.ensureDeferred( - self.store.set_profile_displayname(self.user.localpart, self.displayname) + self.store.set_profile_displayname(self.user.localpart, self.displayname, 1) ) users, total = yield defer.ensureDeferred( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..7a38022e71 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -36,7 +36,7 @@ class ProfileStoreTestCase(unittest.TestCase): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield defer.ensureDeferred( - self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1) ) self.assertEquals( @@ -54,7 +54,7 @@ class ProfileStoreTestCase(unittest.TestCase): yield defer.ensureDeferred( self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.localpart, "http://my.site/here", 1 ) ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store.get_user_by_access_token(self.tokens[1]) ) - self.assertDictContainsSubset( - {"name": self.user_id, "device_id": self.device_id}, result - ) - - self.assertTrue("token_id" in result) + self.assertEqual(result.user_id, self.user_id) + self.assertEqual(result.device_id, self.device_id) + self.assertIsNotNone(result.token_id) @defer.inlineCallbacks def test_user_delete_access_tokens(self): @@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user = yield defer.ensureDeferred( self.store.get_user_by_access_token(self.tokens[0]) ) - self.assertEqual(self.user_id, user["name"]) + self.assertEqual(self.user_id, user.user_id) # now delete the rest yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..ff972daeaa 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -19,7 +19,7 @@ from unittest.mock import Mock from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import event_injection @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py
index d39e792580..1ce4ea3a01 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination @@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, False, None, None) + our_user = create_requester(user_id) room_creator = self.homeserver.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py
@@ -12,9 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import string_types from synapse.api.errors import SynapseError -from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart +from synapse.types import ( + GroupID, + RoomAlias, + UserID, + map_username_to_mxid_localpart, + strip_invalid_mxid_characters, +) from tests import unittest @@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase): self.assertEqual( map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" ) + + +class StripInvalidMxidCharactersTestCase(unittest.TestCase): + def test_return_type(self): + unstripped = strip_invalid_mxid_characters("test") + stripped = strip_invalid_mxid_characters("test@") + + self.assertTrue(isinstance(unstripped, string_types), type(unstripped)) + self.assertTrue(isinstance(stripped, string_types), type(stripped)) + + def test_strip(self): + stripped = strip_invalid_mxid_characters("test@") + self.assertEqual(stripped, "test", stripped) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index e93aa84405..c3c4a93e1f 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py
@@ -50,7 +50,7 @@ async def inject_member_event( sender=sender, state_key=target, content=content, - **kwargs + **kwargs, ) diff --git a/tests/unittest.py b/tests/unittest.py
index 040b126a27..257f465897 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -44,7 +44,7 @@ from synapse.logging.context import ( set_current_context, ) from synapse.server import HomeServer -from synapse.types import Requester, UserID, create_requester +from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import ( @@ -627,7 +627,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) event, context = self.get_success( event_creator.create_event( diff --git a/tests/utils.py b/tests/utils.py
index acec74e9e9..6e7a6fd3cf 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -176,6 +176,8 @@ def default_config(name, parse=False): "update_user_directory": False, "caches": {"global_factor": 1}, "listeners": [{"port": 0, "type": "http"}], + # Enable encryption by default in private rooms + "encryption_enabled_by_default_for_room_type": "invite", } if parse: