summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py6
-rw-r--r--tests/app/test_frontend_proxy.py9
-rw-r--r--tests/app/test_openid_listener.py15
-rw-r--r--tests/federation/test_complexity.py2
-rw-r--r--tests/federation/test_federation_server.py3
-rw-r--r--tests/federation/transport/test_server.py2
-rw-r--r--tests/handlers/test_auth.py6
-rw-r--r--tests/handlers/test_directory.py6
-rw-r--r--tests/handlers/test_identity.py116
-rw-r--r--tests/handlers/test_message.py1
-rw-r--r--tests/handlers/test_oidc.py224
-rw-r--r--tests/handlers/test_password_providers.py580
-rw-r--r--tests/handlers/test_profile.py10
-rw-r--r--tests/handlers/test_register.py107
-rw-r--r--tests/handlers/test_saml.py196
-rw-r--r--tests/handlers/test_stats.py8
-rw-r--r--tests/handlers/test_sync.py14
-rw-r--r--tests/handlers/test_typing.py1
-rw-r--r--tests/handlers/test_user_directory.py133
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py2
-rw-r--r--tests/http/test_additional_resource.py11
-rw-r--r--tests/module_api/test_api.py11
-rw-r--r--tests/push/test_http.py175
-rw-r--r--tests/replication/_base.py11
-rw-r--r--tests/replication/test_client_reader_shard.py33
-rw-r--r--tests/replication/test_multi_media_repo.py10
-rw-r--r--tests/replication/test_sharded_event_persister.py50
-rw-r--r--tests/rest/admin/test_admin.py41
-rw-r--r--tests/rest/admin/test_device.py35
-rw-r--r--tests/rest/admin/test_event_reports.py27
-rw-r--r--tests/rest/admin/test_media.py43
-rw-r--r--tests/rest/admin/test_room.py50
-rw-r--r--tests/rest/admin/test_statistics.py27
-rw-r--r--tests/rest/admin/test_user.py434
-rw-r--r--tests/rest/client/test_consent.py32
-rw-r--r--tests/rest/client/test_ephemeral_message.py1
-rw-r--r--tests/rest/client/test_identity.py132
-rw-r--r--tests/rest/client/test_redactions.py2
-rw-r--r--tests/rest/client/test_retention.py3
-rw-r--r--tests/rest/client/test_room_access_rules.py1070
-rw-r--r--tests/rest/client/test_shadow_banned.py8
-rw-r--r--tests/rest/client/test_third_party_rules.py5
-rw-r--r--tests/rest/client/v1/test_directory.py4
-rw-r--r--tests/rest/client/v1/test_events.py4
-rw-r--r--tests/rest/client/v1/test_login.py48
-rw-r--r--tests/rest/client/v1/test_presence.py17
-rw-r--r--tests/rest/client/v1/test_profile.py7
-rw-r--r--tests/rest/client/v1/test_push_rule_attrs.py33
-rw-r--r--tests/rest/client/v1/test_rooms.py107
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v1/utils.py81
-rw-r--r--tests/rest/client/v2_alpha/test_account.py37
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py39
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py4
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py7
-rw-r--r--tests/rest/client/v2_alpha/test_password_policy.py8
-rw-r--r--tests/rest/client/v2_alpha/test_register.py294
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py18
-rw-r--r--tests/rest/client/v2_alpha/test_shared_rooms.py1
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py18
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py8
-rw-r--r--tests/rest/media/v1/test_media_storage.py24
-rw-r--r--tests/rest/media/v1/test_url_preview.py93
-rw-r--r--tests/rest/test_health.py7
-rw-r--r--tests/rest/test_well_known.py8
-rw-r--r--tests/rulecheck/__init__.py14
-rw-r--r--tests/rulecheck/test_domainrulecheck.py333
-rw-r--r--tests/server.py90
-rw-r--r--tests/server_notices/test_consent.py3
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py3
-rw-r--r--tests/storage/test_cleanup_extrems.py30
-rw-r--r--tests/storage/test_client_ips.py17
-rw-r--r--tests/storage/test_main.py2
-rw-r--r--tests/storage/test_profile.py4
-rw-r--r--tests/test_mau.py2
-rw-r--r--tests/test_server.py43
-rw-r--r--tests/test_state.py1
-rw-r--r--tests/test_terms_auth.py3
-rw-r--r--tests/test_types.py22
-rw-r--r--tests/unittest.py58
-rw-r--r--tests/utils.py4
81 files changed, 4071 insertions, 1081 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index 0fd55f428a..ee5217b074 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -282,7 +282,11 @@ class AuthTestCase(unittest.TestCase): ) ) self.store.add_access_token_to_user.assert_called_with( - USER_ID, token, "DEVICE", None + user_id=USER_ID, + token=token, + device_id="DEVICE", + valid_until_ms=None, + puppets_user_id=None, ) def get_user(tok): diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 4a301b84e1..40abe9d72d 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py
@@ -15,6 +15,7 @@ from synapse.app.generic_worker import GenericWorkerServer +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -55,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = self.make_request("PUT", "presence/a/status") - self.render(request) + _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 400 + unrecognised, because nothing is registered self.assertEqual(channel.code, 400) @@ -77,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = self.make_request("PUT", "presence/a/status") - self.render(request) + _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 401, because the stub servlet still checks authentication self.assertEqual(channel.code, 401) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c2b10d2c70..ea3be95cf1 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py
@@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from synapse.config.server import parse_listener_def +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -66,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/openid/userinfo" + _, channel = make_request( + self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - self.render(request) self.assertEqual(channel.code, 401) @@ -115,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/openid/userinfo" + _, channel = make_request( + self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - self.render(request) self.assertEqual(channel.code, 401) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 1471cc1a28..0187f56e21 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py
@@ -51,7 +51,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.render(request) self.assertEquals(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -64,7 +63,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.render(request) self.assertEquals(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index da933ecd75..3009fbb6c4 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py
@@ -51,7 +51,6 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.render(request) self.assertEquals(400, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -99,7 +98,6 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertEqual( @@ -132,7 +130,6 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.render(request) self.assertEquals(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 72e22d655f..f9e3c7a51f 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py
@@ -40,7 +40,6 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/federation/v1/publicRooms" ) - self.render(request) self.assertEquals(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) @@ -48,5 +47,4 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/federation/v1/publicRooms" ) - self.render(request) self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b5055e018c..e24ce81284 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py
@@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.clock.now = 5000 + self.hs.get_clock().now = 5000 token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): - self.hs.clock.now = 1000 + self.hs.get_clock().now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) user_id = yield defer.ensureDeferred( @@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected - self.hs.clock.now = 6000 + self.hs.get_clock().now = 6000 with self.assertRaises(synapse.api.errors.AuthError): yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(token) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 2ce6dc9528..ee6ef5e6fa 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -412,7 +412,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) - self.render(request) self.assertEquals(403, channel.code, channel.result) def test_allowed(self): @@ -423,7 +422,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): b"directory/room/%23unofficial_test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) - self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -438,7 +436,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() @@ -452,7 +449,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list is enabled so we should get some results request, channel = self.make_request("GET", b"publicRooms") - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -461,7 +457,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we should get no results request, channel = self.make_request("GET", b"publicRooms") - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) @@ -470,5 +465,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py new file mode 100644
index 0000000000..b7d340bcb8 --- /dev/null +++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account + +from tests import unittest + + +class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.address = "test@test" + self.is_server_name = "testis" + self.is_server_url = "https://testis" + self.rewritten_is_url = "https://int.testis" + + config = self.default_config() + config["trusted_third_party_id_servers"] = [self.is_server_name] + config["rewrite_identity_server_urls"] = { + self.is_server_url: self.rewritten_is_url + } + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.get_json.side_effect = defer.succeed({}) + mock_http_client.post_json_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_blacklisting_http_client.get_json.side_effect = defer.succeed({}) + mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_identity_handler().blacklisting_http_client = ( + mock_blacklisting_http_client + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + + def test_rewritten_id_server(self): + """ + Tests that, when validating a 3PID association while rewriting the IS's server + name: + * the bind request is done against the rewritten hostname + * the original, non-rewritten, server name is stored in the database + """ + handler = self.hs.get_identity_handler() + post_json_get_json = handler.blacklisting_http_client.post_json_get_json + store = self.hs.get_datastore() + + creds = {"sid": "123", "client_secret": "some_secret"} + + # Make sure processing the mocked response goes through. + data = self.get_success( + handler.bind_threepid( + client_secret=creds["client_secret"], + sid=creds["sid"], + mxid=self.user_id, + id_server=self.is_server_name, + use_v2=False, + ) + ) + self.assertEqual(data.get("address"), self.address) + + # Check that the request was done against the rewritten server name. + post_json_get_json.assert_called_once_with( + "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,), + { + "sid": creds["sid"], + "client_secret": creds["client_secret"], + "mxid": self.user_id, + }, + headers={}, + ) + + # Check that the original server name is saved in the database instead of the + # rewritten one. + id_servers = self.get_success( + store.get_id_servers_user_bound(self.user_id, "email", self.address) + ) + self.assertEqual(id_servers, [self.is_server_name]) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 8b57081cbe..af42775815 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py
@@ -209,5 +209,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", path, content={}, access_token=self.access_token ) - self.render(request) self.assertEqual(int(channel.result["code"]), 403) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0d51705849..a308c46da9 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json from urllib.parse import parse_qs, urlparse @@ -24,12 +23,8 @@ import pymacaroons from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone -from synapse.handlers.oidc_handler import ( - MappingException, - OidcError, - OidcHandler, - OidcMappingProvider, -) +from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider +from synapse.handlers.sso import MappingException from synapse.types import UserID from tests.unittest import HomeserverTestCase, override_config @@ -94,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider): return {"phone": userinfo["phone"]} +class TestMappingProviderFailures(TestMappingProvider): + async def map_user_attributes(self, userinfo, token, failures): + return { + "localpart": userinfo["username"] + (str(failures) if failures else ""), + "display_name": None, + } + + def simple_async_mock(return_value=None, raises=None): # AsyncMock is not available in python3.5, this mimics part of its behaviour async def cb(*args, **kwargs): @@ -124,22 +127,16 @@ async def get_json(url): class OidcHandlerTestCase(HomeserverTestCase): - def make_homeserver(self, reactor, clock): - - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = "Synapse Test" - - config = self.default_config() + def default_config(self): + config = super().default_config() config["public_baseurl"] = BASE_URL - oidc_config = {} - oidc_config["enabled"] = True - oidc_config["client_id"] = CLIENT_ID - oidc_config["client_secret"] = CLIENT_SECRET - oidc_config["issuer"] = ISSUER - oidc_config["scopes"] = SCOPES - oidc_config["user_mapping_provider"] = { - "module": __name__ + ".TestMappingProvider", + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, } # Update this config with what's in the default config so that @@ -147,13 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config.update(config.get("oidc_config", {})) config["oidc_config"] = oidc_config - hs = self.setup_test_homeserver( - http_client=self.http_client, - proxied_http_client=self.http_client, - config=config, - ) + return config - self.handler = OidcHandler(hs) + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + + self.handler = hs.get_oidc_handler() + sso_handler = hs.get_sso_handler() + # Mock the render error method. + self.render_error = Mock(return_value=None) + sso_handler.render_error = self.render_error + + # Reduce the number of attempts when generating MXIDs. + sso_handler._MAP_USERNAME_RETRIES = 3 return hs @@ -161,12 +169,12 @@ class OidcHandlerTestCase(HomeserverTestCase): return patch.dict(self.handler._provider_metadata, values) def assertRenderedError(self, error, error_description=None): - args = self.handler._render_error.call_args[0] + args = self.render_error.call_args[0] self.assertEqual(args[1], error) if error_description is not None: self.assertEqual(args[2], error_description) # Reset the render_error mock - self.handler._render_error.reset_mock() + self.render_error.reset_mock() def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" @@ -356,7 +364,6 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" - self.handler._render_error = Mock() request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] self.get_success(self.handler.handle_oidc_callback(request)) @@ -387,7 +394,6 @@ class OidcHandlerTestCase(HomeserverTestCase): "preferred_username": "bar", } user_id = "@foo:domain.org" - self.handler._render_error = Mock(return_value=None) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) @@ -435,7 +441,7 @@ class OidcHandlerTestCase(HomeserverTestCase): userinfo, token, user_agent, ip_address ) self.handler._fetch_userinfo.assert_not_called() - self.handler._render_error.assert_not_called() + self.render_error.assert_not_called() # Handle mapping errors self.handler._map_userinfo_to_user = simple_async_mock( @@ -469,7 +475,7 @@ class OidcHandlerTestCase(HomeserverTestCase): userinfo, token, user_agent, ip_address ) self.handler._fetch_userinfo.assert_called_once_with(token) - self.handler._render_error.assert_not_called() + self.render_error.assert_not_called() # Handle userinfo fetching error self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) @@ -485,7 +491,6 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_callback_session(self): """The callback verifies the session presence and validity""" - self.handler._render_error = Mock(return_value=None) request = Mock(spec=["args", "getCookie", "addCookie"]) # Missing cookie @@ -699,19 +704,131 @@ class OidcHandlerTestCase(HomeserverTestCase): ), MappingException, ) - self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + self.assertEqual( + str(e.value), "Mapping provider does not support de-duplicating Matrix IDs", + ) @override_config({"oidc_config": {"allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastore() - user4 = UserID.from_string("@test_user_4:test") + user = UserID.from_string("@test_user:test") + self.get_success( + store.register_user(user_id=user.to_string(), password_hash=None) + ) + + # Map a user via SSO. + userinfo = { + "sub": "test", + "username": "test_user", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Subsequent calls should map to the same mxid. + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Note that a second SSO user can be mapped to the same Matrix ID. (This + # requires a unique sub, but something that maps to the same matrix ID, + # in this case we'll just use the same username. A more realistic example + # would be subs which are email addresses, and mapping from the localpart + # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.) + userinfo = { + "sub": "test1", + "username": "test_user", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Register some non-exact matching cases. + user2 = UserID.from_string("@TEST_user_2:test") + self.get_success( + store.register_user(user_id=user2.to_string(), password_hash=None) + ) + user2_caps = UserID.from_string("@test_USER_2:test") + self.get_success( + store.register_user(user_id=user2_caps.to_string(), password_hash=None) + ) + + # Attempting to login without matching a name exactly is an error. + userinfo = { + "sub": "test2", + "username": "TEST_USER_2", + } + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertTrue( + str(e.value).startswith( + "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" + ) + ) + + # Logging in when matching a name exactly should work. + user2 = UserID.from_string("@TEST_USER_2:test") + self.get_success( + store.register_user(user_id=user2.to_string(), password_hash=None) + ) + + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@TEST_USER_2:test") + + def test_map_userinfo_to_invalid_localpart(self): + """If the mapping provider generates an invalid localpart it should be rejected.""" + userinfo = { + "sub": "test2", + "username": "föö", + } + token = {} + + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "localpart is invalid: föö") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "module": __name__ + ".TestMappingProviderFailures" + } + } + } + ) + def test_map_userinfo_to_user_retries(self): + """The mapping provider can retry generating an MXID if the MXID is already in use.""" + store = self.hs.get_datastore() self.get_success( - store.register_user(user_id=user4.to_string(), password_hash=None) + store.register_user(user_id="@test_user:test", password_hash=None) ) userinfo = { - "sub": "test4", - "username": "test_user_4", + "sub": "test", + "username": "test_user", } token = {} mxid = self.get_success( @@ -719,4 +836,29 @@ class OidcHandlerTestCase(HomeserverTestCase): userinfo, token, "user-agent", "10.10.10.10" ) ) - self.assertEqual(mxid, "@test_user_4:test") + # test_user is already taken, so test_user1 gets registered instead. + self.assertEqual(mxid, "@test_user1:test") + + # Register all of the potential mxids for a particular OIDC username. + self.get_success( + store.register_user(user_id="@tester:test", password_hash=None) + ) + for i in range(1, 3): + self.get_success( + store.register_user(user_id="@tester%d:test" % i, password_hash=None) + ) + + # Now attempt to map to a username, this will fail since all potential usernames are taken. + userinfo = { + "sub": "tester", + "username": "tester", + } + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual( + str(e.value), "Unable to generate a Matrix ID from the SSO response" + ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py new file mode 100644
index 0000000000..ceaf0902d2 --- /dev/null +++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,580 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the password_auth_provider interface""" + +from typing import Any, Type, Union + +from mock import Mock + +from twisted.internet import defer + +import synapse +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import devices +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel +from tests.unittest import override_config + +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + +# a mock instance which the dummy auth providers delegate to, so we can see what's going +# on +mock_password_provider = Mock() + + +class PasswordOnlyAuthProvider: + """A password_provider which only implements `check_password`.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def check_password(self, *args): + return mock_password_provider.check_password(*args) + + +class CustomAuthProvider: + """A password_provider which implements a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class PasswordCustomAuthProvider: + """A password_provider which implements password login via `check_auth`, as well + as a custom type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"m.login.password": ["password"], "test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given password auth providers""" + return { + "password_providers": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + +class PasswordAuthProviderTests(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + devices.register_servlets, + ] + + def setUp(self): + # we use a global mock device, so make sure we are starting with a clean slate + mock_password_provider.reset_mock() + super().setUp() + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_login(self): + # login flows should only have m.login.password + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # login with mxid should work too + channel = self._send_password_login("@u:bz", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:bz", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") + mock_password_provider.reset_mock() + + # try a weird username / pass. Honestly it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with( + "@ USER🙂NAME :test", " pASS😢word " + ) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth(self): + """UI Auth should delegate correctly to the password provider""" + + # create the user, otherwise access doesn't work + module_api = self.hs.get_module_api() + self.get_success(module_api.register_user("u")) + + # log in twice, to get two devices + mock_password_provider.check_password.return_value = defer.succeed(True) + tok1 = self.login("u", "p") + self.login("u", "p", device_id="dev2") + mock_password_provider.reset_mock() + + # have the auth provider deny the request to start with + mock_password_provider.check_password.return_value = defer.succeed(False) + + # make the initial request which returns a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Make another request providing the UI auth flow. + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # Finally, check the request goes through when we allow it + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_login(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 403, channel.result) + + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@localuser:test", channel.json_body["user_id"]) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # have the auth provider deny the request + mock_password_provider.check_password.return_value = defer.succeed(False) + + # log in twice, to get two devices + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Wrong password + channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx") + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "xxx" + ) + mock_password_provider.reset_mock() + + # Right password + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_login(self): + """localdb_enabled can block login with the local password + """ + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_ui_auth(self): + """localdb_enabled can block ui auth with the local password + """ + self.register_user("localuser", "localpass") + + # allow login via the auth provider + mock_password_provider.check_password.return_value = defer.succeed(True) + + # log in twice, to get two devices + tok1 = self.login("localuser", "p") + self.login("localuser", "p", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # m.login.password UIA is permitted because the auth provider allows it, + # even though the localdb does not. + self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}]) + session = channel.json_body["session"] + mock_password_provider.check_password.assert_not_called() + + # now try deleting with the local password + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_auth_disabled(self): + """password auth doesn't work if it's disabled across the board""" + # login flows should be empty + flows = self._get_login_flows() + self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_password.assert_not_called() + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_login(self): + # login flows should have the custom flow and m.login.password, since we + # haven't disabled local password lookup. + # (password must come first, because reasons) + flows = self._get_login_flows() + self.assertEqual( + flows, + [{"type": "m.login.password"}, {"type": "test.login_type"}] + + ADDITIONAL_LOGIN_FLOWS, + ) + + # login with missing param should be rejected + channel = self._send_login("test.login_type", "u") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + mock_password_provider.reset_mock() + + # try a weird username. Again, it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + mock_password_provider.check_auth.return_value = defer.succeed( + "@ MALFORMED! :bz" + ) + channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + " USER🙂NAME ", "test.login_type", {"test_field": " abc "} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_ui_auth(self): + # register the user and log in twice, to get two devices + self.register_user("localuser", "localpass") + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + # missing param + body = { + "auth": { + "type": "test.login_type", + "identifier": {"type": "m.id.user", "user": "localuser"}, + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should + # use it... + self.assertIn("Missing parameters", channel.json_body["error"]) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # right params, but authing as the wrong user + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + body["auth"]["test_field"] = "foo" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + mock_password_provider.reset_mock() + + # and finally, succeed + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_callback(self): + callback = Mock(return_value=defer.succeed(None)) + + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", callback) + ) + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + + # check the args to the callback + callback.assert_called_once() + call_args, call_kwargs = callback.call_args + # should be one positional arg + self.assertEqual(len(call_args), 1) + self.assertEqual(call_args[0]["user_id"], "@user:bz") + for p in ["user_id", "access_token", "device_id", "home_server"]: + self.assertIn(p, call_args[0]) + + @override_config( + {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} + ) + def test_custom_auth_password_disabled(self): + """Test login with a custom auth provider where password login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login(self): + """log in with a custom auth provider which implements password, but password + login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth(self): + """UI Auth with a custom auth provider which implements password, but password + login is disabled""" + # register the user and log in twice via the test login type to get two devices, + self.register_user("localuser", "localpass") + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._send_login("test.login_type", "localuser", test_field="") + self.assertEqual(channel.code, 200, channel.result) + tok1 = channel.json_body["access_token"] + + channel = self._send_login( + "test.login_type", "localuser", test_field="", device_id="dev2" + ) + self.assertEqual(channel.code, 200, channel.result) + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. In particular, "password" should *not* + # be present. + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + mock_password_provider.reset_mock() + + # check that auth with password is rejected + body = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "localuser"}, + "password": "localpass", + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + self.assertEqual( + "Password login has been disabled.", channel.json_body["error"] + ) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # successful auth + body["auth"]["type"] = "test.login_type" + body["auth"]["test_field"] = "x" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "x"} + ) + + @override_config( + { + **providers_config(CustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback(self): + """Test login with a custom auth provider where the local db is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # password login shouldn't work and should be rejected with a 400 + # ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + + def _get_login_flows(self) -> JsonDict: + _, channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["flows"] + + def _send_password_login(self, user: str, password: str) -> FakeChannel: + return self._send_login(type="m.login.password", user=user, password=password) + + def _send_login(self, type, user, **params) -> FakeChannel: + params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) + _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + return channel + + def _start_delete_device_session(self, access_token, device_id) -> str: + """Make an initial delete device request, and return the UI Auth session ID""" + channel = self._delete_device(access_token, device_id) + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + return channel.json_body["session"] + + def _authed_delete_device( + self, + access_token: str, + device_id: str, + session: str, + user_id: str, + password: str, + ) -> FakeChannel: + """Make a delete device request, authenticating with the given uid/password""" + return self._delete_device( + access_token, + device_id, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": user_id}, + "password": password, + "session": session, + }, + }, + ) + + def _delete_device( + self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + _, channel = self.make_request( + "DELETE", "devices/" + device, body, access_token=access_token + ) + return channel diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..1999bcbb38 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -65,7 +65,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): yield defer.ensureDeferred( - self.store.set_profile_displayname(self.frank.localpart, "Frank") + self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) ) displayname = yield defer.ensureDeferred( @@ -113,7 +113,7 @@ class ProfileTestCase(unittest.TestCase): # Setting displayname for the first time is allowed yield defer.ensureDeferred( - self.store.set_profile_displayname(self.frank.localpart, "Frank") + self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) ) self.assertEquals( @@ -166,7 +166,7 @@ class ProfileTestCase(unittest.TestCase): def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) yield defer.ensureDeferred( - self.store.set_profile_displayname("caroline", "Caroline") + self.store.set_profile_displayname("caroline", "Caroline", 1) ) response = yield defer.ensureDeferred( @@ -181,7 +181,7 @@ class ProfileTestCase(unittest.TestCase): def test_get_my_avatar(self): yield defer.ensureDeferred( self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1 ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -232,7 +232,7 @@ class ProfileTestCase(unittest.TestCase): # Setting displayname for the first time is allowed yield defer.ensureDeferred( self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1 ) ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..d979aadcd6 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -18,9 +18,15 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha.register import ( + _map_email_to_displayname, + register_servlets, +) from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester +from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -31,6 +37,10 @@ from .. import unittest class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ + servlets = [ + register_servlets, + ] + def make_homeserver(self, reactor, clock): hs_config = self.default_config() @@ -517,6 +527,103 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(requester.shadow_banned) + def test_email_to_displayname_mapping(self): + """Test that custom emails are mapped to new user displaynames correctly""" + self._check_mapping( + "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]" + ) + + self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]") + + self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]") + + # Multibyte unicode characters + self._check_mapping( + "j\u030a\u0065an-poppy.seed@example.com", + "J\u030a\u0065an-Poppy Seed [Example]", + ) + + def _check_mapping(self, i, expected): + result = _map_email_to_displayname(i) + self.assertEqual(result, expected) + + @override_config( + { + "bind_new_user_emails_to_sydent": "https://is.example.com", + "registrations_require_3pid": ["email"], + "account_threepid_delegates": {}, + "email": { + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + }, + "public_baseurl": "http://localhost", + } + ) + def test_user_email_bound_via_sydent_internal_api(self): + """Tests that emails are bound after registration if this option is set""" + # Register user with an email address + email = "alice@example.com" + + # Mock Synapse's threepid validator + get_threepid_validation_session = Mock( + return_value=make_awaitable( + {"medium": "email", "address": email, "validated_at": 0} + ) + ) + self.store.get_threepid_validation_session = get_threepid_validation_session + delete_threepid_session = Mock(return_value=make_awaitable(None)) + self.store.delete_threepid_session = delete_threepid_session + + # Mock Synapse's http json post method to check for the internal bind call + post_json_get_json = Mock(return_value=make_awaitable(None)) + self.hs.get_simple_http_client().post_json_get_json = post_json_get_json + + # Retrieve a UIA session ID + channel = self.uia_register( + 401, {"username": "alice", "password": "nobodywillguessthis"} + ) + session_id = channel.json_body["session"] + + # Register our email address using the fake validation session above + channel = self.uia_register( + 200, + { + "username": "alice", + "password": "nobodywillguessthis", + "auth": { + "session": session_id, + "type": "m.login.email.identity", + "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"}, + }, + }, + ) + self.assertEqual(channel.json_body["user_id"], "@alice:test") + + # Check that a bind attempt was made to our fake identity server + post_json_get_json.assert_called_with( + "https://is.example.com/_matrix/identity/internal/bind", + {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"}, + ) + + # Check that we stored a mapping of this bind + bound_threepids = self.get_success( + self.store.user_get_bound_threepids("@alice:test") + ) + self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}]) + + def uia_register(self, expected_response: int, body: dict) -> FakeChannel: + """Make a register request.""" + request, channel = self.make_request( + "POST", "register", body + ) # type: SynapseRequest, FakeChannel + + self.assertEqual(request.code, expected_response) + return channel + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py new file mode 100644
index 0000000000..45dc17aba5 --- /dev/null +++ b/tests/handlers/test_saml.py
@@ -0,0 +1,196 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import attr + +from synapse.api.errors import RedirectException +from synapse.handlers.sso import MappingException + +from tests.unittest import HomeserverTestCase, override_config + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" + + +@attr.s +class FakeAuthnResponse: + ava = attr.ib(type=dict) + + +class TestMappingProvider: + def __init__(self, config, module): + pass + + @staticmethod + def parse_config(config): + return + + @staticmethod + def get_saml_attributes(config): + return {"uid"}, {"displayName"} + + def get_remote_user_id(self, saml_response, client_redirect_url): + return saml_response.ava["uid"] + + def saml_response_to_user_attributes( + self, saml_response, failures, client_redirect_url + ): + localpart = saml_response.ava["username"] + (str(failures) if failures else "") + return {"mxid_localpart": localpart, "displayname": None} + + +class TestRedirectMappingProvider(TestMappingProvider): + def saml_response_to_user_attributes( + self, saml_response, failures, client_redirect_url + ): + raise RedirectException(b"https://custom-saml-redirect/") + + +class SamlHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + saml_config = { + "sp_config": {"metadata": {}}, + # Disable grandfathering. + "grandfathered_mxid_source_attribute": None, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + saml_config.update(config.get("saml2_config", {})) + config["saml2_config"] = saml_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_saml_handler() + + # Reduce the number of attempts when generating MXIDs. + sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + def test_map_saml_response_to_user(self): + """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) + # The redirect_url doesn't matter with the default user mapping provider. + redirect_url = "" + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) + def test_map_saml_response_to_existing_user(self): + """Existing users can log in with SAML account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # Map a user via SSO. + saml_response = FakeAuthnResponse( + {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} + ) + redirect_url = "" + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Subsequent calls should map to the same mxid. + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + def test_map_saml_response_to_invalid_localpart(self): + """If the mapping provider generates an invalid localpart it should be rejected.""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) + redirect_url = "" + e = self.get_failure( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "localpart is invalid: föö") + + def test_map_saml_response_to_user_retries(self): + """The mapping provider can retry generating an MXID if the MXID is already in use.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) + redirect_url = "" + mxid = self.get_success( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ) + ) + # test_user is already taken, so test_user1 gets registered instead. + self.assertEqual(mxid, "@test_user1:test") + + # Register all of the potential mxids for a particular SAML username. + self.get_success( + store.register_user(user_id="@tester:test", password_hash=None) + ) + for i in range(1, 3): + self.get_success( + store.register_user(user_id="@tester%d:test" % i, password_hash=None) + ) + + # Now attempt to map to a username, this will fail since all potential usernames are taken. + saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) + e = self.get_failure( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual( + str(e.value), "Unable to generate a Matrix ID from the SSO response" + ) + + @override_config( + { + "saml2_config": { + "user_mapping_provider": { + "module": __name__ + ".TestRedirectMappingProvider" + }, + } + } + ) + def test_map_saml_response_redirect(self): + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) + redirect_url = "" + e = self.get_failure( + self.handler._map_saml_response_to_user( + saml_response, redirect_url, "user-agent", "10.10.10.10" + ), + RedirectException, + ) + self.assertEqual(e.value.location, b"https://custom-saml-redirect/") diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 312c0a0d41..0229f58315 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest # The expected number of state events in a fresh public room. EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5 + # The expected number of state events in a fresh private room. -EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 +# +# Note: we increase this by 2 on the dinsic branch as we send +# a "im.vector.room.access_rules" state event into new private rooms, +# and an encryption state event as all private rooms are encrypted +# by default +EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7 class StatsRoomTests(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e178d7765b..e62586142e 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py
@@ -16,7 +16,7 @@ from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig -from synapse.types import UserID +from synapse.types import UserID, create_requester import tests.unittest import tests.utils @@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) + requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time self.auth_blocking._limit_usage_by_mau = True @@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Check that the happy case does not throw errors self.get_success(self.store.upsert_monthly_active_user(user_id1)) - self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) + self.get_success( + self.sync_handler.wait_for_sync_for_user(requester, sync_config) + ) # Test that global lock works self.auth_blocking._hs_disabled = True e = self.get_failure( - self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + self.sync_handler.wait_for_sync_for_user(requester, sync_config), + ResourceLimitError, ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) + requester = create_requester(user_id2) e = self.get_failure( - self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + self.sync_handler.wait_for_sync_for_user(requester, sync_config), + ResourceLimitError, ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 16ff2e22d2..abbdf2d524 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -228,7 +228,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ), federation_auth_origin=b"farm", ) - self.render(request) self.assertEqual(channel.code, 200) self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 87be94111f..4f6a912ac4 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client.v2_alpha import account, account_validity, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest @@ -537,7 +537,6 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) @@ -546,6 +545,134 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) + + +class UserInfoTestCase(unittest.FederatingHomeserverTestCase): + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + account.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # Set accounts to expire after a week + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + return config + + def prepare(self, reactor, clock, hs): + super(UserInfoTestCase, self).prepare(reactor, clock, hs) + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def test_user_info(self): + """Test /users/info for local users from the Client-Server API""" + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request info about each user from user_three + request, channel = self.make_request( + "POST", + path="/_matrix/client/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + access_token=user_three_token, + shorthand=False, + ) + self.assertEquals(200, channel.code, channel.result) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def test_user_info_federation(self): + """Test that /users/info can be called from the Federation API, and + and that we can query remote users from the Client-Server API + """ + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request information about our local users from the perspective of a remote server + request, channel = self.make_request( + "POST", + path="/_matrix/federation/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + ) + self.assertEquals(200, channel.code) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def setup_test_users(self): + """Create an admin user and three test users, each with a different state""" + + # Create an admin user to expire other users with + self.register_user("admin", "adminpassword", admin=True) + admin_token = self.login("admin", "adminpassword") + + # Create three users + user_one = self.register_user("alice", "pass") + user_one_token = self.login("alice", "pass") + user_two = self.register_user("bob", "pass") + user_three = self.register_user("carl", "pass") + user_three_token = self.login("carl", "pass") + + # Deactivate user_one + self.deactivate(user_one, user_one_token) + + # Expire user_two + self.expire(user_two, admin_token) + + # Do nothing to user_three + + return user_one, user_two, user_three, user_three_token + + def expire(self, user_id_to_expire, admin_tok): + url = "/_synapse/admin/v1/account_validity/validity" + request_data = { + "user_id": user_id_to_expire, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request, channel = self.make_request( + "POST", url, request_data, access_token=admin_tok + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def deactivate(self, user_id, tok): + request_data = { + "auth": {"type": "m.login.password", "user": user_id, "password": "pass"}, + "erase": False, + } + request, channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(request.code, 200) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..c3f7a28dcc 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -101,7 +101,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=self.tls_factory, + tls_client_options_factory=FederationPolicyForHTTPS(config), user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 62d36c2906..05e9c449be 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py
@@ -17,6 +17,7 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import respond_with_json +from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase @@ -43,20 +44,18 @@ class AdditionalResourceTests(HomeserverTestCase): def test_async(self): handler = _AsyncTestCustomEndpoint({}, None).handle_request - self.resource = AdditionalResource(self.hs, handler) + resource = AdditionalResource(self.hs, handler) - request, channel = self.make_request("GET", "/") - self.render(request) + request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") self.assertEqual(request.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) def test_sync(self): handler = _SyncTestCustomEndpoint({}, None).handle_request - self.resource = AdditionalResource(self.hs, handler) + resource = AdditionalResource(self.hs, handler) - request, channel = self.make_request("GET", "/") - self.render(request) + request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") self.assertEqual(request.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9b573ac24d..27206ca3db 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -94,12 +94,13 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertFalse(hasattr(event, "state_key")) self.assertDictEqual(event.content, content) + expected_requester = create_requester( + user_id, authenticated_entity=self.hs.hostname + ) + # Check that the event was sent self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( - create_requester(user_id), - event_dict, - ratelimit=False, - ignore_shadow_ban=True, + expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True, ) # Create and send a state event @@ -128,7 +129,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the event was sent self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( - create_requester(user_id), + expected_requester, { "type": "m.room.power_levels", "content": content, diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8571924b29..cd46c485dd 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mock import Mock from twisted.internet.defer import Deferred @@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import receipts -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class HTTPPusherTests(HomeserverTestCase): @@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, + receipts.register_servlets, ] user_id = True hijack_auth = False @@ -346,8 +347,8 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_mention(self): """ @@ -418,8 +419,8 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_atroom(self): """ @@ -497,5 +498,163 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") + + def test_push_unread_count_group_by_room(self): + """ + The HTTP pusher will group unread count by number of unread rooms. + """ + # Carry out common push count tests and setup + self._test_push_unread_count() + + # Carry out our option-value specific test + # + # This push should still only contain an unread count of 1 (for 1 unread room) + self.assertEqual( + self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 + ) + + @override_config({"push": {"group_unread_count_by_room": False}}) + def test_push_unread_count_message_count(self): + """ + The HTTP pusher will send the total unread message count. + """ + # Carry out common push count tests and setup + self._test_push_unread_count() + + # Carry out our option-value specific test + # + # We're counting every unread message, so there should now be 4 since the + # last read receipt + self.assertEqual( + self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 + ) + + def _test_push_unread_count(self): + """ + Tests that the correct unread count appears in sent push notifications + + Note that: + * Sending messages will cause push notifications to go out to relevant users + * Sending a read receipt will cause a "badge update" notification to go out to + the user that sent the receipt + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("other_user", "pass") + other_access_token = self.login("other_user", "pass") + + # Create a room (as other_user) + room_id = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The user to get notified joins + self.helper.join(room=room_id, user=user_id, tok=access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + response = self.helper.send( + room_id, body="Hello there!", tok=other_access_token + ) + # To get an unread count, the user who is getting notified has to have a read + # position in the room. We'll set the read position to this event in a moment + first_message_event_id = response["event_id"] + + # Advance time a bit (so the pusher will register something has happened) and + # make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + + # Check that the unread count for the room is 0 + # + # The unread count is zero as the user has no read receipt in the room yet + self.assertEqual( + self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + ) + + # Now set the user's read receipt position to the first event + # + # This will actually trigger a new notification to be sent out so that + # even if the user does not receive another message, their unread + # count goes down + request, channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + {}, + access_token=access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Advance time and make the push succeed + self.push_attempts[1][0].callback({}) + self.pump() + + # Unread count is still zero as we've read the only message in the room + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual( + self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 + ) + + # Send another message + self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + + # Advance time and make the push succeed + self.push_attempts[2][0].callback({}) + self.pump() + + # This push should contain an unread count of 1 as there's now been one + # message since our last read receipt + self.assertEqual(len(self.push_attempts), 3) + self.assertEqual( + self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 + ) + + # Since we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room + self.helper.send(room_id, body="Hello?", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[3][0].callback({}) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[4][0].callback({}) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[5][0].callback({}) + + self.assertEqual(len(self.push_attempts), 6) diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 5c633ac6df..295c5d58a6 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -36,7 +36,7 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport, render +from tests.server import FakeTransport try: import hiredis @@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() - self.worker_hs.replication_data_handler = self.test_handler + self.worker_hs._replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( @@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): lambda: self._handle_http_replication_attempt(self.hs, 8765), ) - def create_test_json_resource(self): - """Overrides `HomeserverTestCase.create_test_json_resource`. + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`. """ # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all @@ -347,9 +347,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config["worker_replication_http_port"] = "8765" return config - def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._hs_to_site[worker_hs].resource, self.reactor) - def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..96801db473 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py
@@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker -from tests.server import FakeChannel +from tests.server import FakeChannel, make_request logger = logging.getLogger(__name__) @@ -46,23 +46,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): """Test that registration works when using a single client reader worker. """ worker_hs = self.make_worker_hs("synapse.app.client_reader") + site = self._hs_to_site[worker_hs] - request_1, channel_1 = self.make_request( + request_1, channel_1 = make_request( + self.reactor, + site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_1) self.assertEqual(request_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + request_2, channel_2 = make_request( + self.reactor, + site, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. @@ -74,22 +79,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") - request_1, channel_1 = self.make_request( + site_1 = self._hs_to_site[worker_hs_1] + request_1, channel_1 = make_request( + self.reactor, + site_1, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_1, request_1) self.assertEqual(request_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + site_2 = self._hs_to_site[worker_hs_2] + request_2, channel_2 = make_request( + self.reactor, + site_2, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_2, request_2) self.assertEqual(request_2.code, 200) # We're given a registered user. diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 77c261dbf7..48b574ccbe 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py
@@ -28,7 +28,7 @@ from synapse.server import HomeServer from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.server import FakeChannel, FakeTransport +from tests.server import FakeChannel, FakeSite, FakeTransport, make_request logger = logging.getLogger(__name__) @@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): The channel for the *client* request and the *outbound* request for the media which the caller should respond to. """ - - request, channel = self.make_request( + resource = hs.get_media_repository_resource().children[b"download"] + _, channel = make_request( + self.reactor, + FakeSite(resource), "GET", "/{}/{}".format(target, media_id), shorthand=False, access_token=self.access_token, + await_result=False, ) - request.render(hs.get_media_repository_resource().children[b"download"]) self.pump() clients = self.reactor.tcpClients diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 82cf033d4e..77fc3856d5 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py
@@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import make_request from tests.utils import USE_POSTGRES_FOR_TESTS logger = logging.getLogger(__name__) @@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): sync_hs = self.make_worker_hs( "synapse.app.generic_worker", {"worker_name": "sync"}, ) + sync_hs_site = self._hs_to_site[sync_hs] # Specially selected room IDs that get persisted on different workers. room_id1 = "!foo:test" @@ -178,8 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Do an initial sync so that we're up to date. - request, channel = self.make_request("GET", "/sync", access_token=access_token) - self.render_on_worker(sync_hs, request) + request, channel = make_request( + self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token + ) next_batch = channel.json_body["next_batch"] # We now gut wrench into the events stream MultiWriterIdGenerator on @@ -203,10 +206,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Check that syncing still gets the new event, despite the gap in the # stream IDs. - request, channel = self.make_request( - "GET", "/sync?since={}".format(next_batch), access_token=access_token + request, channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/sync?since={}".format(next_batch), + access_token=access_token, ) - self.render_on_worker(sync_hs, request) # We should only see the new event and nothing else self.assertIn(room_id1, channel.json_body["rooms"]["join"]) @@ -230,12 +236,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) first_event_in_room2 = response["event_id"] - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + sync_hs_site, "GET", "/sync?since={}".format(vector_clock_token), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertNotIn(room_id1, channel.json_body["rooms"]["join"]) self.assertIn(room_id2, channel.json_body["rooms"]["join"]) @@ -254,10 +261,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) - request, channel = self.make_request( - "GET", "/sync?since={}".format(next_batch), access_token=access_token + request, channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/sync?since={}".format(next_batch), + access_token=access_token, ) - self.render_on_worker(sync_hs, request) prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][ "prev_batch" @@ -269,50 +279,54 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back in the first room should not produce any results, as # no events have happened in it. This tests that we are correctly # filtering results based on the vector clock portion. - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + sync_hs_site, "GET", "/rooms/{}/messages?from={}&to={}&dir=b".format( room_id1, prev_batch1, vector_clock_token ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertListEqual([], channel.json_body["chunk"]) # Paginating back on the second room should produce the first event # again. This tests that pagination isn't completely broken. - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + sync_hs_site, "GET", "/rooms/{}/messages?from={}&to={}&dir=b".format( room_id2, prev_batch2, vector_clock_token ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertEqual( channel.json_body["chunk"][0]["event_id"], first_event_in_room2 ) # Paginating forwards should give the same results - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + sync_hs_site, "GET", "/rooms/{}/messages?from={}&to={}&dir=f".format( room_id1, vector_clock_token, prev_batch1 ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertListEqual([], channel.json_body["chunk"]) - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + sync_hs_site, "GET", "/rooms/{}/messages?from={}&to={}&dir=f".format( room_id2, vector_clock_token, prev_batch2, ), access_token=access_token, ) - self.render_on_worker(sync_hs, request) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertEqual( channel.json_body["chunk"][0]["event_id"], first_event_in_room2 diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0f1144fe1e..4f76f8f768 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -30,19 +30,19 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import groups from tests import unittest +from tests.server import FakeSite, make_request class VersionTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/server_version" - def create_test_json_resource(self): + def create_test_resource(self): resource = JsonResource(self.hs) VersionServlet(self.hs).register(resource) return resource def test_version_string(self): request, channel = self.make_request("GET", self.url, shorthand=False) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -75,7 +75,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] @@ -88,14 +87,12 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group @@ -103,7 +100,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token)) # Now delete the group - url = "/admin/delete_group/" + group_id + url = "/_synapse/admin/v1/delete_group/" + group_id request, channel = self.make_request( "POST", url.encode("ascii"), @@ -111,7 +108,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check group returns 404 @@ -131,7 +127,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -143,7 +138,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "GET", "/joined_groups".encode("ascii"), access_token=access_token ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] @@ -222,11 +216,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): def _ensure_quarantined(self, admin_user_tok, server_and_media_id): """Ensure a piece of media is quarantined when trying to access it.""" - request, channel = self.make_request( - "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, + request, channel = make_request( + self.reactor, + FakeSite(self.download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Should be quarantined self.assertEqual( @@ -247,7 +244,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) - self.render(request) # Expect a forbidden error self.assertEqual( @@ -261,7 +257,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) - self.render(request) # Expect a forbidden error self.assertEqual( @@ -287,14 +282,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(self.download_resource), "GET", server_name_and_media_id, shorthand=False, access_token=non_admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Should be successful self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -305,7 +300,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): urllib.parse.quote(media_id), ) request, channel = self.make_request("POST", url, access_token=admin_user_tok,) - self.render(request) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -358,7 +352,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): room_id ) request, channel = self.make_request("POST", url, access_token=admin_user_tok,) - self.render(request) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( @@ -405,7 +398,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) - self.render(request) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -448,7 +440,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) - self.render(request) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -462,14 +453,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) # Attempt to access each piece of media - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(self.download_resource), "GET", server_and_media_id_2, shorthand=False, access_token=non_admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Shouldn't be quarantined self.assertEqual( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index d89eb90cfe..cf3a007598 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -51,19 +51,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): Try to get a device of an user without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) request, channel = self.make_request("PUT", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) request, channel = self.make_request("DELETE", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -75,7 +72,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -83,7 +79,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -91,7 +86,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -108,7 +102,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -116,7 +109,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -124,7 +116,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -141,7 +132,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -149,7 +139,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -157,7 +146,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -173,7 +161,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -181,14 +168,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) request, channel = self.make_request( "DELETE", url, access_token=self.admin_user_tok, ) - self.render(request) # Delete unknown device returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -218,7 +203,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) @@ -227,7 +211,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -247,7 +230,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -255,7 +237,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -272,7 +253,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -280,7 +260,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) @@ -292,7 +271,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -316,7 +294,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -347,7 +324,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): Try to list devices of an user without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -361,7 +337,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -374,7 +349,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -388,7 +362,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -403,7 +376,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -422,7 +394,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) @@ -461,7 +432,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Try to delete devices of an user without authentication. """ request, channel = self.make_request("POST", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -475,7 +445,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", self.url, access_token=other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -488,7 +457,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -502,7 +470,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -518,7 +485,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) # Delete unknown devices returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -550,7 +516,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 303622217f..11b72c10f7 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -75,7 +75,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Try to get an event report without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -88,7 +87,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -101,7 +99,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -117,7 +114,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -133,7 +129,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -149,7 +144,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -167,7 +161,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url + "?room_id=%s" % self.room_id1, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) @@ -188,7 +181,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url + "?user_id=%s" % self.other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) @@ -209,7 +201,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 5) @@ -230,7 +221,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?dir=b", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -247,7 +237,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?dir=f", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -268,7 +257,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -282,7 +270,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -295,7 +282,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -310,7 +296,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -322,7 +307,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -334,7 +318,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -347,7 +330,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -366,7 +348,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase): json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): @@ -419,7 +400,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): Try to get event report without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -432,7 +412,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -445,7 +424,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) @@ -461,7 +439,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/event_reports/-123", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -476,7 +453,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/event_reports/abcdef", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -491,7 +467,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/event_reports/", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -510,7 +485,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/event_reports/123", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -528,7 +502,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 721fa1ed51..dadf9db660 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest +from tests.server import FakeSite, make_request class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): @@ -50,7 +51,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") request, channel = self.make_request("DELETE", url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -67,7 +67,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -81,7 +80,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -95,7 +93,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) @@ -124,14 +121,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(server_name, self.server_name) # Attempt to access media - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(download_resource), "GET", server_and_media_id, shorthand=False, access_token=self.admin_user_tok, ) - request.render(download_resource) - self.pump(1.0) # Should be successful self.assertEqual( @@ -152,7 +149,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -161,14 +157,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) # Attempt to access media - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(download_resource), "GET", server_and_media_id, shorthand=False, access_token=self.admin_user_tok, ) - request.render(download_resource) - self.pump(1.0) self.assertEqual( 404, channel.code, @@ -196,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.handler = hs.get_device_handler() self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname - self.clock = hs.clock self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -210,7 +205,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ request, channel = self.make_request("POST", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -225,7 +219,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -239,7 +232,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) @@ -251,7 +243,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -266,7 +257,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -280,7 +270,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=1234&size_gt=-1234", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -294,7 +283,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=1234&keep_profiles=not_bool", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) @@ -325,7 +313,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( @@ -350,7 +337,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -363,7 +349,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( @@ -387,7 +372,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -399,7 +383,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( @@ -424,7 +407,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() @@ -433,7 +415,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -445,7 +426,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( @@ -471,7 +451,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() @@ -480,7 +459,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -492,7 +470,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( @@ -535,14 +512,14 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): media_id = server_and_media_id.split("/")[1] local_path = self.filepaths.local_media_filepath(media_id) - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(download_resource), "GET", server_and_media_id, shorthand=False, access_token=self.admin_user_tok, ) - request.render(download_resource) - self.pump(1.0) if expect_success: self.assertEqual( diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 535d68f284..46933a0493 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -78,14 +78,13 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): ) # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id + url = "/_synapse/admin/v1/shutdown_room/" + room_id request, channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -110,18 +109,16 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_token, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id + url = "/_synapse/admin/v1/shutdown_room/" + room_id request, channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -136,7 +133,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -145,7 +141,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -192,7 +187,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", self.url, json.dumps({}), access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -206,7 +200,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -220,7 +213,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -239,7 +231,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIn("new_room_id", channel.json_body) @@ -259,7 +250,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -278,7 +268,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) @@ -295,7 +284,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) @@ -322,7 +310,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(None, channel.json_body["new_room_id"]) @@ -356,7 +343,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(None, channel.json_body["new_room_id"]) @@ -391,7 +377,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(None, channel.json_body["new_room_id"]) @@ -439,7 +424,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) @@ -470,7 +454,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Test that room is not purged @@ -488,7 +471,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) @@ -551,7 +533,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -560,7 +541,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -595,7 +575,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -647,7 +626,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) # Check request completed successfully self.assertEqual(200, int(channel.code), msg=channel.json_body) @@ -729,7 +707,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual( 200, int(channel.result["code"]), msg=channel.result["body"] ) @@ -770,7 +747,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def test_correct_room_attributes(self): @@ -794,7 +770,6 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Set this new alias as the canonical alias for this room @@ -822,7 +797,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check that rooms were returned @@ -867,7 +841,6 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.render(request) self.assertEqual( 200, int(channel.result["code"]), msg=channel.result["body"] ) @@ -905,7 +878,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned @@ -1042,7 +1014,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) if expected_http_code != 200: @@ -1104,7 +1075,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) @@ -1152,7 +1122,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( @@ -1164,7 +1133,6 @@ class RoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( @@ -1208,7 +1176,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.second_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1225,7 +1192,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -1242,7 +1208,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1259,7 +1224,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -1280,7 +1244,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("No known servers", channel.json_body["error"]) @@ -1298,7 +1261,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -1318,7 +1280,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) @@ -1328,7 +1289,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) @@ -1349,7 +1309,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1377,7 +1336,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1392,7 +1350,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["room_id"]) @@ -1401,7 +1358,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1422,7 +1378,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["room_id"]) @@ -1432,7 +1387,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 816683a612..907b49f889 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py
@@ -47,7 +47,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): Try to list users without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -59,7 +58,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, json.dumps({}), access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -72,7 +70,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -81,7 +78,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -90,7 +86,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -99,7 +94,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -108,7 +102,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -119,7 +112,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=10&until_ts=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -128,7 +120,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?search_term=", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -137,7 +128,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -151,7 +141,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) @@ -168,7 +157,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -185,7 +173,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -206,7 +193,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) @@ -218,7 +204,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) @@ -230,7 +215,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) @@ -242,7 +226,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_users) @@ -258,7 +241,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -337,7 +319,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) @@ -346,7 +327,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) @@ -362,7 +342,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) @@ -370,7 +349,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) @@ -381,7 +359,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -391,7 +368,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 11) @@ -401,7 +377,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -410,7 +385,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) @@ -476,7 +450,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index d74efede06..54d46f4bd3 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -24,8 +24,8 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError -from synapse.rest.client.v1 import login, profile, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import login, logout, profile, room +from synapse.rest.client.v2_alpha import devices, sync from tests import unittest from tests.test_utils import make_awaitable @@ -41,7 +41,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - self.url = "/_matrix/client/r0/admin/register" + self.url = "/_synapse/admin/v1/register" self.registration_handler = Mock() self.identity_handler = Mock() @@ -71,7 +71,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.config.registration_shared_secret = None request, channel = self.make_request("POST", self.url, b"{}") - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -89,7 +88,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.get_secrets = Mock(return_value=secrets) request, channel = self.make_request("GET", self.url) - self.render(request) self.assertEqual(channel.json_body, {"nonce": "abcd"}) @@ -99,7 +97,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): only last for SALT_TIMEOUT (60s). """ request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] # 59 seconds @@ -107,7 +104,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -116,7 +112,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.reactor.advance(2) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -126,7 +121,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): Only the provided nonce can be used, as it's checked in the MAC. """ request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -143,7 +137,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("HMAC incorrect", channel.json_body["error"]) @@ -154,7 +147,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): user is registered. """ request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -174,7 +166,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -184,7 +175,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): A valid unrecognised nonce. """ request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -201,14 +191,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -222,7 +210,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): def nonce(): request, channel = self.make_request("GET", self.url) - self.render(request) return channel.json_body["nonce"] # @@ -232,7 +219,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -244,7 +230,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce()}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -252,7 +237,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -260,7 +244,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -268,7 +251,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -280,7 +262,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("password must be specified", channel.json_body["error"]) @@ -288,7 +269,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -296,7 +276,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -304,7 +283,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -323,7 +301,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid user type", channel.json_body["error"]) @@ -335,7 +312,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # set no displayname request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -346,19 +322,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob1:test", channel.json_body["user_id"]) request, channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -375,19 +348,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob2:test", channel.json_body["user_id"]) request, channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -404,18 +374,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob3:test", channel.json_body["user_id"]) request, channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) # set displayname request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -432,13 +399,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob4:test", channel.json_body["user_id"]) request, channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @@ -465,7 +430,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Register new user with admin API request, channel = self.make_request("GET", self.url) - self.render(request) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -485,7 +449,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -511,7 +474,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): Try to list users without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) @@ -526,7 +488,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): b"{}", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(3, len(channel.json_body["users"])) @@ -562,7 +523,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) @@ -570,7 +530,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, access_token=self.other_user_token, content=b"{}", ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) @@ -585,7 +544,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v2/users/@unknown_person:test", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) @@ -613,7 +571,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -626,7 +583,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -659,7 +615,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -672,7 +627,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -700,7 +654,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/sync", access_token=self.admin_user_tok ) - self.render(request) if channel.code != 200: raise HttpResponseException( @@ -729,7 +682,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -769,7 +721,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) # Admin user is not blocked by mau anymore self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) @@ -807,7 +758,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -852,7 +802,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -879,7 +828,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -897,7 +845,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -907,7 +854,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -929,7 +875,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -940,7 +885,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -961,7 +905,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -972,7 +915,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -990,7 +932,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._is_erased("@user:test", False) d = self.store.mark_user_erased("@user:test") @@ -1004,7 +945,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Reactivate the user. @@ -1016,14 +956,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): encoding="utf_8" ), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user request, channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -1044,7 +982,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -1054,7 +991,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -1076,7 +1012,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1086,7 +1021,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1102,7 +1036,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1110,7 +1043,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1153,7 +1085,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): Try to list rooms of an user without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1167,7 +1098,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1180,7 +1110,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1194,7 +1123,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1208,7 +1136,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1228,7 +1155,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) @@ -1258,7 +1184,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): Try to list pushers of an user without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1272,7 +1197,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1285,7 +1209,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1299,7 +1222,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1313,7 +1235,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1343,7 +1264,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1383,7 +1303,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Try to list media of an user without authentication. """ request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1397,7 +1316,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1410,7 +1328,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1424,7 +1341,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1441,7 +1357,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1461,7 +1376,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1481,7 +1395,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1497,7 +1410,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -1510,7 +1422,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -1529,7 +1440,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1541,7 +1451,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1553,7 +1462,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1566,7 +1474,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], number_media) @@ -1582,7 +1489,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1600,7 +1506,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) @@ -1638,3 +1543,336 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertIn("last_access_ts", m) self.assertIn("quarantined_by", m) self.assertIn("safe_from_quarantine", m) + + +class UserTokenRestTestCase(unittest.HomeserverTestCase): + """Test for /_synapse/admin/v1/users/<user>/login + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + devices.register_servlets, + logout.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote( + self.other_user + ) + + def _get_token(self) -> str: + request, channel = self.make_request( + "POST", self.url, b"{}", access_token=self.admin_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + return channel.json_body["access_token"] + + def test_no_auth(self): + """Try to login as a user without authentication. + """ + request, channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_not_admin(self): + """Try to login as a user as a non-admin user. + """ + request, channel = self.make_request( + "POST", self.url, b"{}", access_token=self.other_user_tok + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + + def test_send_event(self): + """Test that sending event as a user works. + """ + # Create a room. + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) + + # Login in as the user + puppet_token = self._get_token() + + # Test that sending works, and generates the event as the right user. + resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token) + event_id = resp["event_id"] + event = self.get_success(self.store.get_event(event_id)) + self.assertEqual(event.sender, self.other_user) + + def test_devices(self): + """Tests that logging in as a user doesn't create a new device for them. + """ + # Login in as the user + self._get_token() + + # Check that we don't see a new device in our devices list + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # We should only see the one device (from the login in `prepare`) + self.assertEqual(len(channel.json_body["devices"]), 1) + + def test_logout(self): + """Test that calling `/logout` with the token works. + """ + # Login in as the user + puppet_token = self._get_token() + + # Test that we can successfully make a request + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=puppet_token + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Logout with the puppet token + request, channel = self.make_request( + "POST", "logout", b"{}", access_token=puppet_token + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # The puppet token should no longer work + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=puppet_token + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + + # .. but the real user's tokens should still work + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_user_logout_all(self): + """Tests that the target user calling `/logout/all` does *not* expire + the token. + """ + # Login in as the user + puppet_token = self._get_token() + + # Test that we can successfully make a request + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=puppet_token + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Logout all with the real user token + request, channel = self.make_request( + "POST", "logout/all", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # The puppet token should still work + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=puppet_token + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # .. but the real user's tokens shouldn't + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + + def test_admin_logout_all(self): + """Tests that the admin user calling `/logout/all` does expire the + token. + """ + # Login in as the user + puppet_token = self._get_token() + + # Test that we can successfully make a request + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=puppet_token + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Logout all with the admin user token + request, channel = self.make_request( + "POST", "logout/all", b"{}", access_token=self.admin_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # The puppet token should no longer work + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=puppet_token + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + + # .. but the real user's tokens should still work + request, channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + @unittest.override_config( + { + "public_baseurl": "https://example.org/", + "user_consent": { + "version": "1.0", + "policy_name": "My Cool Privacy Policy", + "template_dir": "/", + "require_at_registration": True, + "block_events_error": "You should accept the policy", + }, + "form_secret": "123secret", + } + ) + def test_consent(self): + """Test that sending a message is not subject to the privacy policies. + """ + # Have the admin user accept the terms. + self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0")) + + # First, cheekily accept the terms and create a room + self.get_success(self.store.user_set_consent_version(self.other_user, "1.0")) + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) + self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok) + + # Now unaccept it and check that we can't send an event + self.get_success(self.store.user_set_consent_version(self.other_user, "0.0")) + self.helper.send_event( + room_id, "com.example.test", tok=self.other_user_tok, expect_code=403 + ) + + # Login in as the user + puppet_token = self._get_token() + + # Sending an event on their behalf should work fine + self.helper.send_event(room_id, "com.example.test", tok=puppet_token) + + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0} + ) + def test_mau_limit(self): + # Create a room as the admin user. This will bump the monthly active users to 1. + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Trying to join as the other user should fail due to reaching MAU limit. + self.helper.join( + room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403 + ) + + # Logging in as the other user and joining a room should work, even + # though the MAU limit would stop the user doing so. + puppet_token = self._get_token() + self.helper.join(room_id, user=self.other_user, tok=puppet_token) + + +class WhoisRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user) + self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to get information of an user without authentication. + """ + request, channel = self.make_request("GET", self.url1, b"{}") + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + request, channel = self.make_request("GET", self.url2, b"{}") + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.register_user("user2", "pass") + other_user2_token = self.login("user2", "pass") + + request, channel = self.make_request( + "GET", self.url1, access_token=other_user2_token, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + request, channel = self.make_request( + "GET", self.url2, access_token=other_user2_token, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" + + request, channel = self.make_request( + "GET", url1, access_token=self.admin_user_tok, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only whois a local user", channel.json_body["error"]) + + request, channel = self.make_request( + "GET", url2, access_token=self.admin_user_tok, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only whois a local user", channel.json_body["error"]) + + def test_get_whois_admin(self): + """ + The lookup should succeed for an admin. + """ + request, channel = self.make_request( + "GET", self.url1, access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) + + request, channel = self.make_request( + "GET", self.url2, access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) + + def test_get_whois_user(self): + """ + The lookup should succeed for a normal user looking up their own information. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url1, access_token=other_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) + + request, channel = self.make_request( + "GET", self.url2, access_token=other_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 6803b372ac..e2e6a5e16d 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py
@@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.consent import consent_resource from tests import unittest -from tests.server import render +from tests.server import FakeSite, make_request class ConsentResourceTestCase(unittest.HomeserverTestCase): @@ -61,8 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): def test_render_public_consent(self): """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) - request, channel = self.make_request("GET", "/consent?v=1", shorthand=False) - render(request, resource, self.reactor) + request, channel = make_request( + self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False + ) self.assertEqual(channel.code, 200) def test_accept_consent(self): @@ -81,10 +82,14 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") + "&u=user" ) - request, channel = self.make_request( - "GET", consent_uri, access_token=access_token, shorthand=False + request, channel = make_request( + self.reactor, + FakeSite(resource), + "GET", + consent_uri, + access_token=access_token, + shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented @@ -92,21 +97,26 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): self.assertEqual(consented, "False") # POST to the consent page, saying we've agreed - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(resource), "POST", consent_uri + "&v=" + version, access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Fetch the consent page, to get the consent version -- it should have # changed - request, channel = self.make_request( - "GET", consent_uri, access_token=access_token, shorthand=False + request, channel = make_request( + self.reactor, + FakeSite(resource), + "GET", + consent_uri, + access_token=access_token, + shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and check that it's the version we diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 5e9c07ebf3..a1ccc4ee9a 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py
@@ -94,7 +94,6 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) request, channel = self.make_request("GET", url) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..aa094c3543 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@ import json +from mock import Mock + +from twisted.internet import defer + import synapse.rest.admin from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import account from tests import unittest -class IdentityTestCase(unittest.HomeserverTestCase): +class IdentityDisabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts fail when the HS's config disallows them.""" servlets = [ + account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, @@ -32,21 +39,95 @@ class IdentityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["trusted_third_party_id_servers"] = ["testis"] config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_disabled(self): + request, channel = self.make_request( + b"POST", "/createRoom", b"{}", access_token=self.tok + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + room_id = channel.json_body["room_id"] + + params = { + "id_server": "testis", + "medium": "email", + "address": "test@example.com", + } + request_data = json.dumps(params) + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") + request, channel = self.make_request( + b"POST", request_url, request_data, access_token=self.tok + ) + self.assertEquals(channel.result["code"], b"403", channel.result) + def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) + request, channel = self.make_request("GET", url, access_token=self.tok) + self.assertEqual(channel.result["code"], b"403", channel.result) + + def test_3pid_bulk_lookup_disabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok + ) + self.assertEqual(channel.result["code"], b"403", channel.result) + + +class IdentityEnabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts succeed when the HS's config allows them.""" + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["enable_3pid_lookup"] = True + config["trusted_third_party_id_servers"] = ["testis"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.get_json.return_value = defer.succeed({"mxid": "@f:test"}) + mock_http_client.post_json_get_json.return_value = defer.succeed({}) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_identity_handler().http_client = mock_http_client + + return self.hs - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + def test_3pid_invite_enabled(self): request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok + b"POST", "/createRoom", b"{}", access_token=self.tok ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) room_id = channel.json_body["room_id"] @@ -58,7 +139,40 @@ class IdentityTestCase(unittest.HomeserverTestCase): request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok + b"POST", request_url, request_data, access_token=self.tok + ) + + get_json = self.hs.get_identity_handler().http_client.get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + {"address": "test@example.com", "medium": "email"}, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_3pid_lookup_enabled(self): + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) + self.make_request("GET", url, access_token=self.tok) + + get_json = self.hs.get_simple_http_client().get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + {"address": "foo@bar.baz", "medium": "email"}, + ) + + def test_3pid_bulk_lookup_enabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], + } + request_data = json.dumps(data) + self.make_request("POST", url, request_data, access_token=self.tok) + + post_json = self.hs.get_simple_http_client().post_json_get_json + post_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/bulk_lookup", + {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]}, ) - self.render(request) - self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index d2bcf256fa..c1f516cc93 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py
@@ -72,7 +72,6 @@ class RedactionsTestCase(HomeserverTestCase): request, channel = self.make_request( "POST", path, content={}, access_token=access_token ) - self.render(request) self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body @@ -80,7 +79,6 @@ class RedactionsTestCase(HomeserverTestCase): request, channel = self.make_request( "GET", "sync", access_token=self.mod_access_token ) - self.render(request) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7d3773ff78..62b41ada42 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["default_room_version"] = "1" config["retention"] = { "enabled": True, "default_policy": { @@ -243,6 +244,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["default_room_version"] = "1" config["retention"] = { "enabled": True, } @@ -326,7 +328,6 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) request, channel = self.make_request("GET", url, access_token=self.token) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py new file mode 100644
index 0000000000..f02a613a28 --- /dev/null +++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1070 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import random +import string +from typing import Optional + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.api.errors import SynapseError +from synapse.rest import admin +from synapse.rest.client.v1 import directory, login, room +from synapse.third_party_rules.access_rules import ( + ACCESS_RULES_TYPE, + AccessRules, + RoomAccessRules, +) +from synapse.types import JsonDict, create_requester + +from tests import unittest + + +class RoomAccessTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["third_party_event_rules"] = { + "module": "synapse.third_party_rules.access_rules.RoomAccessRules", + "config": { + "domains_forbidden_when_restricted": ["forbidden_domain"], + "id_server": "testis", + }, + } + config["trusted_third_party_id_servers"] = ["testis"] + + def send_invite(destination, room_id, event_id, pdu): + return defer.succeed(pdu) + + def get_json(uri, args={}, headers=None): + address_domain = args["address"].split("@")[1] + return defer.succeed({"hs": address_domain}) + + def post_json_get_json(uri, post_json, args={}, headers=None): + token = "".join(random.choice(string.ascii_letters) for _ in range(10)) + return defer.succeed( + { + "token": token, + "public_keys": [ + { + "public_key": "serverpublickey", + "key_validity_url": "https://testis/pubkey/isvalid", + }, + { + "public_key": "phemeralpublickey", + "key_validity_url": "https://testis/pubkey/ephemeral/isvalid", + }, + ], + "display_name": "f...@b...", + } + ) + + mock_federation_client = Mock(spec=["send_invite"]) + mock_federation_client.send_invite.side_effect = send_invite + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"],) + # Mocking the response for /info on the IS API. + mock_http_client.get_json.side_effect = get_json + # Mocking the response for /store-invite on the IS API. + mock_http_client.post_json_get_json.side_effect = post_json_get_json + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + simple_http_client=mock_http_client, + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_identity_handler().blacklisting_http_client = mock_http_client + + self.third_party_event_rules = self.hs.get_third_party_event_rules() + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.restricted_room = self.create_room() + self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED) + self.direct_rooms = [ + self.create_room(direct=True), + self.create_room(direct=True), + self.create_room(direct=True), + ] + + self.invitee_id = self.register_user("invitee", "test") + self.invitee_tok = self.login("invitee", "test") + + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + ) + + def test_create_room_no_rule(self): + """Tests that creating a room with no rule will set the default.""" + room_id = self.create_room() + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.RESTRICTED) + + def test_create_room_direct_no_rule(self): + """Tests that creating a direct room with no rule will set the default.""" + room_id = self.create_room(direct=True) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.DIRECT) + + def test_create_room_valid_rule(self): + """Tests that creating a room with a valid rule will set the right.""" + room_id = self.create_room(rule=AccessRules.UNRESTRICTED) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.UNRESTRICTED) + + def test_create_room_invalid_rule(self): + """Tests that creating a room with an invalid rule will set fail.""" + self.create_room(rule=AccessRules.DIRECT, expected_code=400) + + def test_create_room_direct_invalid_rule(self): + """Tests that creating a direct room with an invalid rule will fail. + """ + self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400) + + def test_create_room_default_power_level_rules(self): + """Tests that a room created with no power level overrides instead uses the dinum + defaults + """ + room_id = self.create_room(direct=True, rule=AccessRules.DIRECT) + power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok) + + # Inviting another user should require PL50, even in private rooms + self.assertEqual(power_levels["invite"], 50) + # Sending arbitrary state events should require PL100 + self.assertEqual(power_levels["state_default"], 100) + + def test_create_room_fails_on_incorrect_power_level_rules(self): + """Tests that a room created with power levels lower than that required are rejected""" + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + modified_power_levels["invite"] = 0 + modified_power_levels["state_default"] = 50 + + self.create_room( + direct=True, + rule=AccessRules.DIRECT, + initial_state=[ + {"type": "m.room.power_levels", "content": modified_power_levels} + ], + expected_code=400, + ) + + def test_create_room_with_missing_power_levels_use_default_values(self): + """ + Tests that a room created with custom power levels, but without defining invite or state_default + succeeds, but the missing values are replaced with the defaults. + """ + + # Attempt to create a room without defining "invite" or "state_default" + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + del modified_power_levels["invite"] + del modified_power_levels["state_default"] + room_id = self.create_room( + direct=True, + rule=AccessRules.DIRECT, + initial_state=[ + {"type": "m.room.power_levels", "content": modified_power_levels} + ], + ) + + # This should succeed, but the defaults should be put in place instead + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + self.assertEqual(room_power_levels["invite"], 50) + self.assertEqual(room_power_levels["state_default"], 100) + + # And now the same test, but using power_levels_content_override instead + # of initial_state (which takes a slightly different codepath) + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + del modified_power_levels["invite"] + del modified_power_levels["state_default"] + room_id = self.create_room( + direct=True, + rule=AccessRules.DIRECT, + power_levels_content_override=modified_power_levels, + ) + + # This should succeed, but the defaults should be put in place instead + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + self.assertEqual(room_power_levels["invite"], 50) + self.assertEqual(room_power_levels["state_default"], 100) + + def test_existing_room_can_change_power_levels(self): + """Tests that a room created with default power levels can have their power levels + dropped after room creation + """ + # Creates a room with the default power levels + room_id = self.create_room( + direct=True, rule=AccessRules.DIRECT, expected_code=200, + ) + + # Attempt to drop invite and state_default power levels after the fact + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + room_power_levels["invite"] = 0 + room_power_levels["state_default"] = 50 + self.helper.send_state( + room_id, "m.room.power_levels", room_power_levels, self.tok + ) + + def test_public_room(self): + """Tests that it's only possible to have a room listed in the public room list + if the access rule is restricted. + """ + # Creating a room with the public_chat preset should succeed and set the access + # rule to restricted. + preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT) + self.assertEqual( + self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED + ) + + # Creating a room with the public join rule in its initial state should succeed + # and set the access rule to restricted. + init_state_room_id = self.create_room( + initial_state=[ + { + "type": "m.room.join_rules", + "content": {"join_rule": JoinRules.PUBLIC}, + } + ] + ) + self.assertEqual( + self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED + ) + + # List preset_room_id in the public room list + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,), + {"visibility": "public"}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + # List init_state_room_id in the public room list + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,), + {"visibility": "public"}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + # Changing access rule to unrestricted should fail. + self.change_rule_in_room( + preset_room_id, AccessRules.UNRESTRICTED, expected_code=403 + ) + self.change_rule_in_room( + init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403 + ) + + # Changing access rule to direct should fail. + self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403) + self.change_rule_in_room( + init_state_room_id, AccessRules.DIRECT, expected_code=403 + ) + + # Creating a new room with the public_chat preset and an access rule of direct + # should fail. + self.create_room( + preset=RoomCreationPreset.PUBLIC_CHAT, + rule=AccessRules.DIRECT, + expected_code=400, + ) + + # Changing join rule to public in an direct room should fail. + self.change_join_rule_in_room( + self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403 + ) + + def test_restricted(self): + """Tests that in restricted mode we're unable to invite users from blacklisted + servers but can invite other users. + + Also tests that the room can be published to, and removed from, the public room + list. + """ + # We can't invite a user from a forbidden HS. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=403, + ) + + # We can invite a user which HS isn't forbidden. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:allowed_domain", + tok=self.tok, + expect_code=200, + ) + + # We can't send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.restricted_room, + expected_code=403, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.restricted_room, + expected_code=200, + ) + + # We are allowed to publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + # We are allowed to remove the room from the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room + data = {"visibility": "private"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + def test_direct(self): + """Tests that, in direct mode, other users than the initial two can't be invited, + but the following scenario works: + * invited user joins the room + * invited user leaves the room + * room creator re-invites invited user + + Tests that a user from a HS that's in the list of forbidden domains (to use + in restricted mode) can be invited. + + Tests that the room cannot be published to the public room list. + """ + not_invited_user = "@not_invited:forbidden_domain" + + # We can't invite a new user to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=403, + ) + + # The invited user can join the room. + self.helper.join( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can leave the room. + self.helper.leave( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can be re-invited to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + expect_code=200, + ) + + # If we're alone in the room and have always been the only member, we can invite + # someone. + self.helper.invite( + room=self.direct_rooms[1], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=200, + ) + + # Disable the 3pid invite ratelimiter + burst = self.hs.config.rc_third_party_invite.burst_count + per_second = self.hs.config.rc_third_party_invite.per_second + self.hs.config.rc_third_party_invite.burst_count = 10 + self.hs.config.rc_third_party_invite.per_second = 0.1 + + # We can't send a 3PID invite to a room that already has two members. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[0], + expected_code=403, + ) + + # We can't send a 3PID invite to a room that already has a pending invite. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[1], + expected_code=403, + ) + + # We can send a 3PID invite to a room in which we've always been the only member. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=200, + ) + + # We can send a 3PID invite to a room in which there's a 3PID invite. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=403, + ) + + self.hs.config.rc_third_party_invite.burst_count = burst + self.hs.config.rc_third_party_invite.per_second = per_second + + # We can't publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0] + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.assertEqual(channel.code, 403, channel.result) + + def test_unrestricted(self): + """Tests that, in unrestricted mode, we can invite whoever we want, but we can + only change the power level of users that wouldn't be forbidden in restricted + mode. + + Tests that the room cannot be published to the public room list. + """ + # We can invite + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:not_forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a power level event that doesn't redefine the default PL or set a + # non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}}, + tok=self.tok, + expect_code=200, + ) + + # We can't send a power level event that redefines the default PL and doesn't set + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={ + "users": {self.user_id: 100, "@test:not_forbidden_domain": 10}, + "users_default": 10, + }, + tok=self.tok, + expect_code=403, + ) + + # We can't send a power level event that doesn't redefines the default PL but sets + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}}, + tok=self.tok, + expect_code=403, + ) + + # We can't publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.assertEqual(channel.code, 403, channel.result) + + def test_change_rules(self): + """Tests that we can only change the current rule from restricted to + unrestricted. + """ + # We can't change the rule from restricted to direct. + self.change_rule_in_room( + room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403 + ) + + # We can change the rule from restricted to unrestricted. + # Note that this changes self.restricted_room to an unrestricted room + self.change_rule_in_room( + room_id=self.restricted_room, + new_rule=AccessRules.UNRESTRICTED, + expected_code=200, + ) + + # We can't change the rule from unrestricted to restricted. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=AccessRules.RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from unrestricted to direct. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=AccessRules.DIRECT, + expected_code=403, + ) + + # We can't change the rule from direct to restricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=AccessRules.RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from direct to unrestricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=AccessRules.UNRESTRICTED, + expected_code=403, + ) + + # We can't publish a room to the public room list and then change its rule to + # unrestricted + + # Create a restricted room + test_room_id = self.create_room(rule=AccessRules.RESTRICTED) + + # Publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + # Attempt to switch the room to "unrestricted" + self.change_rule_in_room( + room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403 + ) + + # Attempt to switch the room to "direct" + self.change_rule_in_room( + room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403 + ) + + def test_change_room_avatar(self): + """Tests that changing the room avatar is always allowed unless the room is a + direct chat, in which case it's forbidden. + """ + + avatar_content = { + "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394}, + "url": "mxc://example.org/JWEIFJgwEIhweiWJE", + } + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=403, + ) + + def test_change_room_name(self): + """Tests that changing the room name is always allowed unless the room is a direct + chat, in which case it's forbidden. + """ + + name_content = {"name": "My super room"} + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=403, + ) + + def test_change_room_topic(self): + """Tests that changing the room topic is always allowed unless the room is a + direct chat, in which case it's forbidden. + """ + + topic_content = {"topic": "Welcome to this room"} + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=403, + ) + + def test_revoke_3pid_invite_direct(self): + """Tests that revoking a 3PID invite doesn't cause the room access rules module to + confuse the revokation as a new 3PID invite. + """ + invite_token = "sometoken" + + invite_body = { + "display_name": "ker...@exa...", + "public_keys": [ + { + "key_validity_url": "https://validity_url", + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", + }, + { + "key_validity_url": "https://validity_url", + "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I", + }, + ], + "key_validity_url": "https://validity_url", + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", + } + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body=invite_body, + tok=self.tok, + ) + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body={}, + tok=self.tok, + ) + + invite_token = "someothertoken" + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body=invite_body, + tok=self.tok, + ) + + def test_check_event_allowed(self): + """Tests that RoomAccessRules.check_event_allowed behaves accordingly. + + It tests that: + * forbidden users cannot join restricted rooms. + * forbidden users can only join unrestricted rooms if they have an invite. + """ + event_creator = self.hs.get_event_creation_handler() + + # Test that forbidden users cannot join restricted rooms + requester = create_requester(self.user_id) + allowed_requester = create_requester("@user:allowed_domain") + forbidden_requester = create_requester("@user:forbidden_domain") + + # Assert a join event from a forbidden user to a restricted room is rejected + self.get_failure( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.restricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ), + SynapseError, + ) + + # A join event from an non-forbidden user to a restricted room is allowed + self.get_success( + event_creator.create_event( + allowed_requester, + { + "type": EventTypes.Member, + "room_id": self.restricted_room, + "sender": allowed_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": allowed_requester.user.to_string(), + }, + ) + ) + + # Test that forbidden users can only join unrestricted rooms if they have an invite + + # A forbidden user without an invite should not be able to join an unrestricted room + self.get_failure( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.unrestricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ), + SynapseError, + ) + + # However, if we then invite this user... + self.helper.invite( + room=self.unrestricted_room, + src=requester.user.to_string(), + targ=forbidden_requester.user.to_string(), + tok=self.tok, + ) + + # And create another join event, making sure that its context states it's coming + # in after the above invite was made... + # Then the forbidden user should be able to join! + self.get_success( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.unrestricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ) + ) + + def test_freezing_a_room(self): + """Tests that the power levels in a room change to prevent new events from + non-admin users when the last admin of a room leaves. + """ + + def freeze_room_with_id_and_power_levels( + room_id: str, custom_power_levels_content: Optional[JsonDict] = None, + ): + # Invite a user to the room, they join with PL 0 + self.helper.invite( + room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok, + ) + + # Invitee joins the room + self.helper.join( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + if not custom_power_levels_content: + # Retrieve the room's current power levels event content + power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.tok, + ) + else: + power_levels = custom_power_levels_content + + # Override the room's power levels with the given power levels content + self.helper.send_state( + room_id=room_id, + event_type="m.room.power_levels", + body=custom_power_levels_content, + tok=self.tok, + ) + + # Ensure that the invitee leaving the room does not change the power levels + self.helper.leave( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + # Retrieve the new power levels of the room + new_power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.tok, + ) + + # Ensure they have not changed + self.assertDictEqual(power_levels, new_power_levels) + + # Invite the user back again + self.helper.invite( + room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok, + ) + + # Invitee joins the room + self.helper.join( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + # Now the admin leaves the room + self.helper.leave( + room=room_id, user=self.user_id, tok=self.tok, + ) + + # Check the power levels again + new_power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok, + ) + + # Ensure that the new power levels prevent anyone but admins from sending + # certain events + self.assertEquals(new_power_levels["state_default"], 100) + self.assertEquals(new_power_levels["events_default"], 100) + self.assertEquals(new_power_levels["kick"], 100) + self.assertEquals(new_power_levels["invite"], 100) + self.assertEquals(new_power_levels["ban"], 100) + self.assertEquals(new_power_levels["redact"], 100) + self.assertDictEqual(new_power_levels["events"], {}) + self.assertDictEqual(new_power_levels["users"], {self.user_id: 100}) + + # Ensure new users entering the room aren't going to immediately become admins + self.assertEquals(new_power_levels["users_default"], 0) + + # Test that freezing a room with the default power level state event content works + room1 = self.create_room() + freeze_room_with_id_and_power_levels(room1) + + # Test that freezing a room with a power level state event that is missing + # `state_default` and `event_default` keys behaves as expected + room2 = self.create_room() + freeze_room_with_id_and_power_levels( + room2, + { + "ban": 50, + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + }, + "invite": 0, + "kick": 50, + "redact": 50, + "users": {self.user_id: 100}, + "users_default": 0, + # Explicitly remove `state_default` and `event_default` keys + }, + ) + + # Test that freezing a room with a power level state event that is *additionally* + # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected + room3 = self.create_room() + freeze_room_with_id_and_power_levels( + room3, + { + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + }, + "users": {self.user_id: 100}, + "users_default": 0, + # Explicitly remove `state_default` and `event_default` keys + # Explicitly remove `ban`, `invite`, `kick` and `redact` keys + }, + ) + + def create_room( + self, + direct=False, + rule=None, + preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + initial_state=None, + power_levels_content_override=None, + expected_code=200, + ): + content = {"is_direct": direct, "preset": preset} + + if rule: + content["initial_state"] = [ + {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}} + ] + + if initial_state: + if "initial_state" not in content: + content["initial_state"] = [] + + content["initial_state"] += initial_state + + if power_levels_content_override: + content["power_levels_content_override"] = power_levels_content_override + + request, channel = self.make_request( + "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok, + ) + + self.assertEqual(channel.code, expected_code, channel.result) + + if expected_code == 200: + return channel.json_body["room_id"] + + def current_rule_in_room(self, room_id): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + access_token=self.tok, + ) + + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["rule"] + + def change_rule_in_room(self, room_id, new_rule, expected_code=200): + data = {"rule": new_rule} + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + json.dumps(data), + access_token=self.tok, + ) + + self.assertEqual(channel.code, expected_code, channel.result) + + def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200): + data = {"join_rule": new_join_rule} + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules), + json.dumps(data), + access_token=self.tok, + ) + + self.assertEqual(channel.code, expected_code, channel.result) + + def send_threepid_invite(self, address, room_id, expected_code=200): + params = {"id_server": "testis", "medium": "email", "address": address} + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/invite" % room_id, + json.dumps(params), + access_token=self.tok, + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def send_state_with_state_key( + self, room_id, event_type, state_key, body, tok, expect_code=200 + ): + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) + + request, channel = self.make_request( + "PUT", path, json.dumps(body), access_token=tok + ) + + self.assertEqual(channel.code, expect_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 6bb02b9630..94dcfb9f7c 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -95,7 +95,6 @@ class RoomTestCase(_ShadowBannedBase): {"id_server": "test", "medium": "email", "address": "test@test.test"}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # This should have raised an error earlier, but double check this wasn't called. @@ -110,7 +109,6 @@ class RoomTestCase(_ShadowBannedBase): {"visibility": "public", "invite": [self.other_user_id]}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) room_id = channel.json_body["room_id"] @@ -166,7 +164,6 @@ class RoomTestCase(_ShadowBannedBase): {"new_version": "6"}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # A new room_id should be returned. self.assertIn("replacement_room", channel.json_body) @@ -192,7 +189,6 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code) # There should be no typing events. @@ -208,7 +204,6 @@ class RoomTestCase(_ShadowBannedBase): {"typing": True, "timeout": 30000}, access_token=self.other_access_token, ) - self.render(request) self.assertEquals(200, channel.code) # These appear in the room. @@ -255,7 +250,6 @@ class ProfileTestCase(_ShadowBannedBase): {"displayname": new_display_name}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertEqual(channel.json_body, {}) @@ -263,7 +257,6 @@ class ProfileTestCase(_ShadowBannedBase): request, channel = self.make_request( "GET", "/profile/%s/displayname" % (self.banned_user_id,) ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.json_body["displayname"], new_display_name) @@ -296,7 +289,6 @@ class ProfileTestCase(_ShadowBannedBase): {"membership": "join", "displayname": new_display_name}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertIn("event_id", channel.json_body) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0048bea54a..0e96697f9b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -92,7 +92,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {}, access_token=self.tok, ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) callback.assert_called_once() @@ -111,7 +110,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {}, access_token=self.tok, ) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) def test_cannot_modify_event(self): @@ -131,7 +129,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.result["code"], b"500", channel.result) def test_modify_event(self): @@ -151,7 +148,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.result["code"], b"200", channel.result) event_id = channel.json_body["event_id"] @@ -161,7 +157,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.result["code"], b"200", channel.result) ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index ea5a7f3739..7a2c653df8 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py
@@ -94,7 +94,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, 400, channel.result) def test_room_creation(self): @@ -108,7 +107,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) def set_alias_via_state_event(self, expected_code, alias_length=5): @@ -123,7 +121,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def set_alias_via_directory(self, expected_code, alias_length=5): @@ -134,7 +131,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def random_alias(self, length): diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 3397ba5579..12a93f5687 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py
@@ -66,14 +66,12 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.render(request) self.assertEquals(channel.code, 401, msg=channel.result) # valid token, expect content request, channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.render(request) self.assertEquals(channel.code, 200, msg=channel.result) self.assertTrue("chunk" in channel.json_body) self.assertTrue("start" in channel.json_body) @@ -92,7 +90,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.render(request) self.assertEquals(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down @@ -155,5 +152,4 @@ class GetEventsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/events/" + event_id, access_token=self.token, ) - self.render(request) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 5d987a30c7..ee75f4076f 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -64,7 +64,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -84,7 +83,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -111,7 +109,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -131,7 +128,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -158,7 +154,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "notamonkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -178,7 +173,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "notamonkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) @@ -188,7 +182,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we shouldn't be able to make requests without an access token request, channel = self.make_request(b"GET", TEST_URL) - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") @@ -199,7 +192,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEquals(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] @@ -209,7 +201,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 200, channel.result) # time passes @@ -219,7 +210,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -236,7 +226,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -247,7 +236,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], False) @@ -257,7 +245,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( @@ -281,7 +268,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.render(request) self.assertEquals(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) @@ -295,7 +281,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 200, channel.result) # time passes @@ -305,7 +290,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -314,7 +298,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"POST", "/logout", access_token=access_token ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) @@ -328,7 +311,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 200, channel.result) # time passes @@ -338,7 +320,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"GET", TEST_URL, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -347,7 +328,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"POST", "/logout/all", access_token=access_token ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -423,7 +403,6 @@ class CASTestCase(unittest.HomeserverTestCase): # Get Synapse to call the fake CAS and serve the template. request, channel = self.make_request("GET", cas_ticket_url) - self.render(request) # Test that the response is HTML. self.assertEqual(channel.code, 200) @@ -468,7 +447,6 @@ class CASTestCase(unittest.HomeserverTestCase): # Get Synapse to call the fake CAS and serve the template. request, channel = self.make_request("GET", cas_ticket_url) - self.render(request) self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") @@ -495,7 +473,6 @@ class CASTestCase(unittest.HomeserverTestCase): # Get Synapse to call the fake CAS and serve the template. request, channel = self.make_request("GET", cas_ticket_url) - self.render(request) # Because the user is deactivated they are served an error template. self.assertEqual(channel.code, 403) @@ -518,15 +495,18 @@ class JWTTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs - def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} ) request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) return channel def test_login_jwt_valid_registered(self): @@ -659,7 +639,6 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_no_token(self): params = json.dumps({"type": "org.matrix.login.jwt"}) request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -725,15 +704,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = "RS256" return self.hs - def jwt_encode(self, token, secret=jwt_privatekey): - return jwt.encode(token, secret, "RS256").decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} ) request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) return channel def test_login_jwt_valid(self): @@ -766,7 +748,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), {"username": username}, ) - self.render(request) def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() @@ -815,7 +796,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_user_bot(self): @@ -831,7 +811,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_wrong_user(self): @@ -847,7 +826,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) def test_login_appservice_wrong_as(self): @@ -863,7 +841,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) def test_login_appservice_no_token(self): @@ -878,5 +855,4 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 3c66255dac..5d5c24d01c 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py
@@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): + presence_handler = Mock() + presence_handler.set_state.return_value = defer.succeed(None) + hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() + "red", + http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, ) - hs.presence_handler = Mock() - hs.presence_handler.set_state.return_value = defer.succeed(None) - return hs def test_put_presence(self): @@ -53,10 +56,9 @@ class PresenceTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.render(request) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) def test_put_presence_disabled(self): """ @@ -69,7 +71,6 @@ class PresenceTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.render(request) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index ace0a3c08d..383a9eafac 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py
@@ -195,7 +195,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): content=json.dumps({"displayname": "test"}), access_token=self.owner_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) res = self.get_displayname() @@ -209,7 +208,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): content=json.dumps({"displayname": "test" * 100}), access_token=self.owner_tok, ) - self.render(request) self.assertEqual(channel.code, 400, channel.result) res = self.get_displayname() @@ -219,7 +217,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/profile/%s/displayname" % (self.owner,) ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] @@ -284,7 +281,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def ensure_requester_left_room(self): @@ -327,7 +323,6 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/profile/" + self.requester, access_token=self.requester_tok ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) request, channel = self.make_request( @@ -335,7 +330,6 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): "/profile/" + self.requester + "/displayname", access_token=self.requester_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) request, channel = self.make_request( @@ -343,5 +337,4 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): "/profile/" + self.requester + "/avatar_url", access_token=self.requester_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 081052f6a6..7add5523c8 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -48,14 +48,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET enabled for that new rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) @@ -79,7 +77,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # disable the rule @@ -89,14 +86,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"enabled": False}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # check rule disabled request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) @@ -104,21 +99,18 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # PUT a new rule request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET enabled for that new rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) @@ -141,7 +133,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # disable the rule @@ -151,14 +142,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"enabled": False}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # check rule disabled request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) @@ -169,14 +158,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"enabled": True}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # check rule enabled request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) @@ -198,7 +185,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -206,28 +192,24 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET enabled for that new rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # DELETE the rule request, channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # check 404 for deleted rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -242,7 +224,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -260,7 +241,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"enabled": True}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -278,7 +258,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"enabled": True}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -300,14 +279,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET actions for that new rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] @@ -331,7 +308,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # change the rule actions @@ -341,14 +317,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"actions": ["dont_notify"]}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # GET actions for that new rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["actions"], ["dont_notify"]) @@ -370,7 +344,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -378,21 +351,18 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # DELETE the rule request, channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # check 404 for deleted rule request, channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -407,7 +377,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): request, channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -425,7 +394,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"actions": ["dont_notify"]}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -443,6 +411,5 @@ class PushRuleAttributesTestCase(HomeserverTestCase): {"actions": ["dont_notify"]}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 9ba5f9d943..49f1073c88 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -86,7 +86,6 @@ class RoomPermissionsTestCase(RoomBase): request, channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # set topic for public room @@ -95,7 +94,6 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # auth as user_id now @@ -118,12 +116,10 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 @@ -131,19 +127,16 @@ class RoomPermissionsTestCase(RoomBase): room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): @@ -154,20 +147,16 @@ class RoomPermissionsTestCase(RoomBase): request, channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) request, channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) request, channel = self.make_request("GET", topic_path) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 @@ -175,12 +164,10 @@ class RoomPermissionsTestCase(RoomBase): room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 request, channel = self.make_request("GET", topic_path) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 @@ -189,29 +176,24 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id request, channel = self.make_request("GET", topic_path) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) request, channel = self.make_request("GET", topic_path) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 request, channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 @@ -220,14 +202,12 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) request, channel = self.make_request("GET", path) - self.render(request) self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): @@ -400,18 +380,15 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): @@ -421,19 +398,16 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. request, channel = self.make_request("GET", room_path) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined request, channel = self.make_request("GET", room_path) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left request, channel = self.make_request("GET", room_path) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -446,7 +420,6 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id request, channel = self.make_request("POST", "/createRoom", "{}") - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) @@ -455,7 +428,6 @@ class RoomsCreateTestCase(RoomBase): request, channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private"}' ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) @@ -464,7 +436,6 @@ class RoomsCreateTestCase(RoomBase): request, channel = self.make_request( "POST", "/createRoom", b'{"custom":"stuff"}' ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) @@ -473,18 +444,15 @@ class RoomsCreateTestCase(RoomBase): request, channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 request, channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.render(request) self.assertEquals(400, channel.code) request, channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.render(request) self.assertEquals(400, channel.code) def test_post_room_invitees_invalid_mxid(self): @@ -493,7 +461,6 @@ class RoomsCreateTestCase(RoomBase): request, channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.render(request) self.assertEquals(400, channel.code) @@ -510,52 +477,42 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json request, channel = self.make_request("PUT", self.path, "{}") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", self.path, '{"nao') - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", self.path, "text only") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", self.path, "") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' request, channel = self.make_request("PUT", self.path, content) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there request, channel = self.make_request("GET", self.path) - self.render(request) self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' request, channel = self.make_request("PUT", self.path, content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get request, channel = self.make_request("GET", self.path) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -563,12 +520,10 @@ class RoomTopicTestCase(RoomBase): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' request, channel = self.make_request("PUT", self.path, content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get request, channel = self.make_request("GET", self.path) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -585,29 +540,23 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json request, channel = self.make_request("PUT", path, "{}") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, '{"nao') - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request( "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, "text only") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, "") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types @@ -617,7 +566,6 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) request, channel = self.make_request("PUT", path, content.encode("ascii")) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): @@ -629,11 +577,9 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN request, channel = self.make_request("PUT", path, content.encode("ascii")) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) request, channel = self.make_request("GET", path, None) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} @@ -649,11 +595,9 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE request, channel = self.make_request("PUT", path, content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) request, channel = self.make_request("GET", path, None) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -670,11 +614,9 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) request, channel = self.make_request("PUT", path, content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) request, channel = self.make_request("GET", path, None) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -725,7 +667,6 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.render(request) self.assertEquals(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. @@ -736,7 +677,6 @@ class RoomJoinRatelimitTestCase(RoomBase): ) request, channel = self.make_request("GET", path) - self.render(request) self.assertEquals(channel.code, 200) self.assertIn("displayname", channel.json_body) @@ -761,7 +701,6 @@ class RoomJoinRatelimitTestCase(RoomBase): # if all of these requests ended up joining the user to a room. for i in range(4): request, channel = self.make_request("POST", path % room_id, {}) - self.render(request) self.assertEquals(channel.code, 200) @@ -777,29 +716,23 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json request, channel = self.make_request("PUT", path, b"{}") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, b'{"nao') - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request( "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, b"text only") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) request, channel = self.make_request("PUT", path, b"") - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): @@ -807,20 +740,17 @@ class RoomMessagesTestCase(RoomBase): content = b'{"body":"test","msgtype":{"type":"a"}}' request, channel = self.make_request("PUT", path, content) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' request, channel = self.make_request("PUT", path, content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' request, channel = self.make_request("PUT", path, content) - self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -837,7 +767,6 @@ class RoomInitialSyncTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/initialSync" % self.room_id ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) @@ -881,7 +810,6 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body["start"]) @@ -893,7 +821,6 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body["start"]) @@ -933,7 +860,6 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) chunk = channel.json_body["chunk"] @@ -962,7 +888,6 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) chunk = channel.json_body["chunk"] @@ -980,7 +905,6 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) chunk = channel.json_body["chunk"] @@ -1040,7 +964,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): } }, ) - self.render(request) # Check we get the results we expect -- one search result, of the sent # messages @@ -1074,7 +997,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): } }, ) - self.render(request) # Check we get the results we expect -- one search result, of the sent # messages @@ -1111,7 +1033,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self): request, channel = self.make_request("GET", self.url) - self.render(request) self.assertEqual(channel.code, 401, channel.result) def test_restricted_auth(self): @@ -1119,7 +1040,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): tok = self.login("user", "pass") request, channel = self.make_request("GET", self.url, access_token=tok) - self.render(request) self.assertEqual(channel.code, 200, channel.result) @@ -1153,7 +1073,6 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -1168,7 +1087,6 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] @@ -1177,7 +1095,6 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) res_displayname = channel.json_body["content"]["displayname"] @@ -1212,7 +1129,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1227,7 +1143,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1242,7 +1157,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1257,7 +1171,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1270,7 +1183,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1283,7 +1195,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1303,7 +1214,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1316,7 +1226,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) event_content = channel.json_body @@ -1365,7 +1274,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1396,7 +1304,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1432,7 +1339,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1460,7 +1366,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), ) - self.render(request) events = channel.json_body["chunk"] @@ -1478,7 +1383,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), ) - self.render(request) events = channel.json_body["chunk"] @@ -1507,7 +1411,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): json.dumps(self.FILTER_LABELS_NOT_LABELS), ), ) - self.render(request) events = channel.json_body["chunk"] @@ -1532,7 +1435,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) - self.render(request) results = channel.json_body["search_categories"]["room_events"]["results"] @@ -1568,7 +1470,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) - self.render(request) results = channel.json_body["search_categories"]["room_events"]["results"] @@ -1616,7 +1517,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) - self.render(request) results = channel.json_body["search_categories"]["room_events"]["results"] @@ -1741,7 +1641,6 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1806,7 +1705,6 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1897,7 +1795,6 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): % (self.room_id,), access_token=access_token, ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) res = channel.json_body self.assertIsInstance(res, dict) @@ -1916,7 +1813,6 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) @@ -1947,7 +1843,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: @@ -1957,7 +1852,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), access_token=self.room_owner_tok, ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) res = channel.json_body self.assertIsInstance(res, dict) @@ -1971,7 +1865,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): json.dumps(content), access_token=self.room_owner_tok, ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) res = channel.json_body self.assertIsInstance(res, dict) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index cd58ee7792..bbd30f594b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py
@@ -99,7 +99,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) @@ -123,7 +122,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', ) - self.render(request) self.assertEquals(200, channel.code) def test_typing_timeout(self): @@ -132,7 +130,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) @@ -146,7 +143,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index afaf9f7b85..737c38c396 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -23,10 +23,11 @@ from typing import Any, Dict, Optional import attr from twisted.web.resource import Resource +from twisted.web.server import Site from synapse.api.constants import Membership -from tests.server import make_request, render +from tests.server import FakeSite, make_request @attr.s @@ -36,25 +37,51 @@ class RestHelper: """ hs = attr.ib() - resource = attr.ib() + site = attr.ib(type=Site) auth_user_id = attr.ib() def create_room_as( - self, room_creator=None, is_public=True, tok=None, expect_code=200, - ): + self, + room_creator: str = None, + is_public: bool = True, + room_version: str = None, + tok: str = None, + expect_code: int = 200, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" content = {} if not is_public: content["visibility"] = "private" + if room_version: + content["room_version"] = room_version if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + path, + json.dumps(content).encode("utf8"), ) - render(request, self.resource, self.hs.get_reactor()) assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id @@ -124,12 +151,14 @@ class RestHelper: data = {"membership": membership} data.update(extra_data) - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(data).encode("utf8"), ) - render(request, self.resource, self.hs.get_reactor()) - assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" % (expect_code, int(channel.result["code"]), channel.result["body"]) @@ -157,10 +186,13 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(content).encode("utf8"), ) - render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -210,9 +242,9 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - request, channel = make_request(self.hs.get_reactor(), method, path, content) - - render(request, self.resource, self.hs.get_reactor()) + _, channel = make_request( + self.hs.get_reactor(), self.site, method, path, content + ) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -295,14 +327,15 @@ class RestHelper: """ image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - request, channel = make_request( - self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok - ) - request.requestHeaders.addRawHeader( - b"Content-Length", str(image_length).encode("UTF-8") + _, channel = make_request( + self.hs.get_reactor(), + FakeSite(resource), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], ) - request.render(resource) - self.hs.get_reactor().pump([100]) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 66ac4dbe85..2ac1ecb7d3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest +from tests.server import FakeSite, make_request from tests.unittest import override_config @@ -245,7 +246,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) return channel.json_body["sid"] @@ -255,9 +255,14 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") # Load the password reset confirmation page - request, channel = self.make_request("GET", path, shorthand=False) - request.render(self.submit_token_resource) - self.pump() + request, channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "GET", + path, + shorthand=False, + ) + self.assertEquals(200, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the @@ -271,15 +276,15 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): form_args.append(arg) # Confirm the password reset - request, channel = self.make_request( + request, channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), "POST", path, content=urlencode(form_args).encode("utf8"), shorthand=False, content_is_form=True, ) - request.render(self.submit_token_resource) - self.pump() self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -319,7 +324,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): }, }, ) - self.render(request) self.assertEquals(expected_code, channel.code, channel.result) @@ -349,7 +353,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): # Check that this access token has been invalidated. request, channel = self.make_request("GET", "account/whoami") - self.render(request) self.assertEqual(request.code, 401) def test_pending_invites(self): @@ -407,7 +410,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.render(request) self.assertEqual(request.code, 200) @@ -542,7 +544,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -550,7 +551,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -575,14 +575,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user request, channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -609,7 +607,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -618,7 +615,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -647,7 +643,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) @@ -655,7 +650,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -682,7 +676,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) @@ -690,7 +683,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -795,7 +787,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", b"account/3pid/email/requestToken", body, ) - self.render(request) self.assertEquals(expect_code, channel.code, channel.result) return channel.json_body.get("sid") @@ -808,7 +799,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) @@ -818,7 +808,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -867,14 +856,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user request, channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 86184f0d2e..77246e478f 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) -class DummyPasswordChecker(UserInteractiveAuthChecker): - def check_auth(self, authdict, clientip): - return succeed(authdict["identifier"]["user"]) - - class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ @@ -72,7 +67,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "register", body ) # type: SynapseRequest, FakeChannel - self.render(request) self.assertEqual(request.code, expected_response) return channel @@ -87,7 +81,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session ) # type: SynapseRequest, FakeChannel - self.render(request) self.assertEqual(request.code, 200) request, channel = self.make_request( @@ -96,7 +89,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): + post_session + "&g-recaptcha-response=a", ) - self.render(request) self.assertEqual(request.code, expected_post_response) # The recaptcha handler is called with the response given @@ -165,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) - self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.user_tok = self.login("test", self.user_pass) @@ -177,7 +166,6 @@ class UIAuthTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "devices", access_token=self.user_tok, ) # type: SynapseRequest, FakeChannel - self.render(request) # Get the ID of the device. self.assertEqual(request.code, 200) @@ -190,7 +178,6 @@ class UIAuthTests(unittest.HomeserverTestCase): request, channel = self.make_request( "DELETE", "devices/" + device, body, access_token=self.user_tok ) # type: SynapseRequest, FakeChannel - self.render(request) # Ensure the response is sane. self.assertEqual(request.code, expected_response) @@ -204,7 +191,6 @@ class UIAuthTests(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "delete_devices", body, access_token=self.user_tok, ) # type: SynapseRequest, FakeChannel - self.render(request) # Ensure the response is sane. self.assertEqual(request.code, expected_response) @@ -240,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + device_id = self.get_device_ids()[0] + channel = self.delete_device(device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + def test_can_change_body(self): """ The client dict can be modified during the user interactive authentication session. diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index b9e01c9418..767e126875 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -37,7 +37,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def test_check_auth_required(self): request, channel = self.make_request("GET", self.url) - self.render(request) self.assertEqual(channel.code, 401) @@ -46,7 +45,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): access_token = self.login("user", "pass") request, channel = self.make_request("GET", self.url, access_token=access_token) - self.render(request) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -65,7 +63,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): access_token = self.login(user, password) request, channel = self.make_request("GET", self.url, access_token=access_token) - self.render(request) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -74,7 +71,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) request, channel = self.make_request("GET", self.url, access_token=access_token) - self.render(request) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index de00350580..231d5aefea 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -41,7 +41,6 @@ class FilterTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.render(request) self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) @@ -55,7 +54,6 @@ class FilterTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, ) - self.render(request) self.assertEqual(channel.result["code"], b"403") self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) @@ -68,7 +66,6 @@ class FilterTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.render(request) self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") @@ -85,7 +82,6 @@ class FilterTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"200") self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) @@ -94,7 +90,6 @@ class FilterTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"404") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -105,7 +100,6 @@ class FilterTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"400") @@ -114,6 +108,5 @@ class FilterTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index c57072f50c..ee86b94917 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -73,7 +73,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/password_policy" ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self.assertEqual( @@ -91,7 +90,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_too_short(self): request_data = json.dumps({"username": "kermit", "password": "shorty"}) request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -101,7 +99,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_digit(self): request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -111,7 +108,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_symbol(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -121,7 +117,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_uppercase(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -131,7 +126,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_lowercase(self): request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -141,7 +135,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_compliant(self): request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. @@ -173,7 +166,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): request_data, access_token=tok, ) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 98c3887bbf..b53b47a0cc 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime import json import os +from mock import Mock + import pkg_resources +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -64,7 +68,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} @@ -76,26 +79,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): - request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): user_id = "@kermit:test" device_id = "frogfone" @@ -107,7 +100,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) det_data = { "user_id": user_id, @@ -123,7 +115,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") @@ -133,7 +124,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.allow_guest_access = True request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) @@ -143,7 +133,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.allow_guest_access = False request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") @@ -153,7 +142,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") - self.render(request) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -164,7 +152,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -179,7 +166,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -190,13 +176,11 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_advertised_flows(self): request, channel = self.make_request(b"POST", self.url, b"{}") - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -220,7 +204,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_captcha_and_terms_and_3pids(self): request, channel = self.make_request(b"POST", self.url, b"{}") - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -253,7 +236,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_no_msisdn_email_required(self): request, channel = self.make_request(b"POST", self.url, b"{}") - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -298,12 +280,52 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) +class RegisterHideProfileTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/register" + + config = self.default_config() + config["enable_registration"] = True + config["show_users_in_user_directory"] = False + config["replicate_user_profiles_to"] = ["fakeserver"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + return self.hs + + def test_profile_hidden(self): + user_id = self.register_user("kermit", "monkey") + + post_json = self.hs.get_simple_http_client().post_json_get_json + + # We expect post_json_get_json to have been called twice: once with the original + # profile and once with the None profile resulting from the request to hide it + # from the user directory. + self.assertEqual(post_json.call_count, 2, post_json.call_args_list) + + # Get the args (and not kwargs) passed to post_json. + args = post_json.call_args[0] + # Make sure the last call was attempting to replicate profiles. + split_uri = args[0].split("/") + self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0]) + # Make sure the last profile update was overriding the user's profile to None. + self.assertEqual(args[1]["batch"][user_id], None, args[1]) + + class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ @@ -313,6 +335,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): sync.register_servlets, logout.register_servlets, account_validity.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -334,14 +357,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( @@ -360,19 +381,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.register_user("admin", "adminpassword", admin=True) admin_tok = self.login("admin", "adminpassword") - url = "/_matrix/client/unstable/admin/account_validity/validity" + url = "/_synapse/admin/v1/account_validity/validity" params = {"user_id": user_id} request_data = json.dumps(params) request, channel = self.make_request( b"POST", url, request_data, access_token=admin_tok ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_manual_expire(self): @@ -382,7 +401,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.register_user("admin", "adminpassword", admin=True) admin_tok = self.login("admin", "adminpassword") - url = "/_matrix/client/unstable/admin/account_validity/validity" + url = "/_synapse/admin/v1/account_validity/validity" params = { "user_id": user_id, "expiration_ts": 0, @@ -392,13 +411,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"POST", url, request_data, access_token=admin_tok ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result @@ -411,7 +428,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.register_user("admin", "adminpassword", admin=True) admin_tok = self.login("admin", "adminpassword") - url = "/_matrix/client/unstable/admin/account_validity/validity" + url = "/_synapse/admin/v1/account_validity/validity" params = { "user_id": user_id, "expiration_ts": 0, @@ -421,12 +438,10 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( b"POST", url, request_data, access_token=admin_tok ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # Try to log the user out request, channel = self.make_request(b"POST", "/logout", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) @@ -434,10 +449,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # Try to log out all of the user's sessions request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) +class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.client.v1.profile.register_servlets, + synapse.rest.client.v1.room.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Set accounts to expire after a week + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + config["replicate_user_profiles_to"] = "test.is" + + # Mock homeserver requests to an identity server + mock_http_client = Mock(spec=["post_json_get_json"]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + return self.hs + + def test_expired_user_in_directory(self): + """Test that an expired user is hidden in the user directory""" + # Create an admin user to search the user directory + admin_id = self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + # Ensure the admin never expires + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": admin_id, + "expiration_ts": 999999999999, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Mock the homeserver's HTTP client + post_json = self.hs.get_simple_http_client().post_json_get_json + + # Create a user + username = "kermit" + user_id = self.register_user(username, "monkey") + self.login(username, "monkey") + self.get_success( + self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1) + ) + + # Check that a full profile for this user is replicated + self.assertIsNotNone(post_json.call_args, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + + self.assertIsNotNone(batch, batch) + self.assertEquals(len(batch), 1, batch) + + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's not None + replicated_content = batch[user_id] + self.assertIsNotNone(replicated_content) + + # Expire the user + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Wait for the background job to run which hides expired users in the directory + self.reactor.advance(60 * 60 * 1000) + + # Check if the homeserver has replicated the user's profile to the identity server + self.assertIsNotNone(post_json.call_args, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + + self.assertIsNotNone(batch, batch) + self.assertEquals(len(batch), 1, batch) + + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's None, signifying that the user should be removed from the user + # directory because they were expired + replicated_content = batch[user_id] + self.assertIsNone(replicated_content) + + # Now renew the user, and check they get replicated again to the identity server + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 99999999999, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.pump(10) + self.reactor.advance(10) + self.pump() + + # Check if the homeserver has replicated the user's profile to the identity server + post_json = self.hs.get_simple_http_client().post_json_get_json + self.assertNotEquals(post_json.call_args, None, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + self.assertNotEquals(batch, None, batch) + self.assertEquals(len(batch), 1, batch) + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's not None, signifying that the user is back in the user + # directory + replicated_content = batch[user_id] + self.assertIsNotNone(replicated_content) + + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): servlets = [ @@ -500,8 +660,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - # Move 6 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=6).total_seconds()) + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) self.assertEqual(len(self.email_attempts), 1) # Retrieving the URL from the email is too much pain for now, so we @@ -509,18 +669,35 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token request, channel = self.make_request(b"GET", url) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect on a successful renewal. - expected_html = self.hs.config.account_validity.account_renewed_html_content + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + request, channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -530,7 +707,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_renewal_invalid_token(self): @@ -538,19 +714,15 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" request, channel = self.make_request(b"GET", url) - self.render(request) self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. - content_type = None - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type = header[1] - self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result) + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) # Check that the HTML we're getting is the one we expect when using an # invalid/unknown token. - expected_html = self.hs.config.account_validity.invalid_token_html_content + expected_html = self.hs.config.account_validity_invalid_token_template.render() self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) @@ -564,7 +736,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -587,7 +758,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.render(request) self.assertEqual(request.code, 200) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) @@ -599,7 +769,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -617,7 +787,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -641,7 +811,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -661,7 +830,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) - self.hs.config.account_validity.period = self.validity_period + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta self.store = self.hs.get_datastore() @@ -675,9 +849,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): """ user_id = self.register_user("kermit_delta", "user") - self.hs.config.account_validity.startup_job_max_delta = self.max_delta - - now_ms = self.hs.clock.time_msec() + now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 99c9f4e928..6cd4eb6624 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -65,7 +65,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assert_dict( @@ -114,7 +113,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the full @@ -160,7 +158,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) @@ -219,7 +216,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -296,7 +292,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): ), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -336,7 +331,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( @@ -369,7 +363,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content={}, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) request, channel = self.make_request( @@ -378,7 +371,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( @@ -396,7 +388,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) - self.render(request) self.assertEquals(400, channel.code, channel.json_body) def test_aggregation_get_event(self): @@ -428,7 +419,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( @@ -465,7 +455,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["content"], new_body) @@ -523,7 +512,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["content"], new_body) @@ -567,7 +555,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) @@ -581,7 +568,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations @@ -591,7 +577,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # Check that no relations are returned @@ -624,7 +609,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # Check that aggregations returns zero @@ -634,7 +618,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): % (self.room, original_event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) @@ -680,7 +663,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): json.dumps(content).encode("utf-8"), access_token=access_token, ) - self.render(request) return channel def _create_user(self, localpart): diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 5ae72fd008..562a9c1ba4 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -47,7 +47,6 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): % other_user, access_token=token, ) - self.render(request) return request, channel def test_shared_room_list_public(self): diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a31e44c97e..31ac0fccb8 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -36,7 +36,6 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_sync_argless(self): request, channel = self.make_request("GET", "/sync") - self.render(request) self.assertEqual(channel.code, 200) self.assertTrue( @@ -57,7 +56,6 @@ class FilterTestCase(unittest.HomeserverTestCase): self.hs.config.use_presence = False request, channel = self.make_request("GET", "/sync") - self.render(request) self.assertEqual(channel.code, 200) self.assertTrue( @@ -199,7 +197,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/sync?filter=%s" % sync_filter, access_token=tok ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] @@ -253,13 +250,11 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) request, channel = self.make_request( "GET", "/sync?access_token=%s" % (access_token,) ) - self.render(request) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -269,7 +264,6 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', ) - self.render(request) self.assertEquals(200, channel.code) # Start typing. @@ -278,14 +272,12 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) # Should return immediately request, channel = self.make_request( "GET", sync_url % (access_token, next_batch) ) - self.render(request) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -300,7 +292,6 @@ class SyncTypingTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", sync_url % (access_token, next_batch) ) - self.render(request) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -311,7 +302,6 @@ class SyncTypingTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", sync_url % (access_token, next_batch) ) - self.render(request) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -320,10 +310,8 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing._reset() # Now it SHOULD fail as it never completes! - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.assertRaises(TimedOutException, self.render, request) + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) class UnreadMessagesTestCase(unittest.HomeserverTestCase): @@ -401,7 +389,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): body, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) # Check that the unread counter is back to 0. @@ -466,7 +453,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", self.url % self.next_batch, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6850c666be..fbcf8d5b86 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string from tests import unittest -from tests.server import FakeChannel, wait_until_result +from tests.server import FakeChannel from tests.utils import default_config @@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): self.http_client = Mock() return self.setup_test_homeserver(http_client=self.http_client) - def create_test_json_resource(self): + def create_test_resource(self): return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) @@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): % (server_name.encode("utf-8"), key_id.encode("utf-8")), b"1.1", ) - wait_until_result(self.reactor, req) + channel.await_result() self.assertEqual(channel.code, 200) resp = channel.json_body return resp @@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): req.requestReceived( b"POST", path.encode("utf-8"), b"1.1", ) - wait_until_result(self.reactor, req) + channel.await_result() self.assertEqual(channel.code, 200) resp = channel.json_body return resp diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 5f897d49cf..2a3b2a8f27 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py
@@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from tests import unittest +from tests.server import FakeSite, make_request class MediaStorageTests(unittest.HomeserverTestCase): @@ -227,8 +228,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _req(self, content_disposition): - request, channel = self.make_request("GET", self.media_id, shorthand=False) - request.render(self.download_resource) + request, channel = make_request( + self.reactor, + FakeSite(self.download_resource), + "GET", + self.media_id, + shorthand=False, + await_result=False, + ) self.pump() # We've made one fetch, to example.com, using the media URL, and asking @@ -317,10 +324,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method - request, channel = self.make_request( - "GET", self.media_id + params, shorthand=False + request, channel = make_request( + self.reactor, + FakeSite(self.thumbnail_resource), + "GET", + self.media_id + params, + shorthand=False, + await_result=False, ) - request.render(self.thumbnail_resource) self.pump() headers = { @@ -348,7 +359,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_NOT_FOUND", - "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']" - % method, + "error": "Not found [b'example.com', b'12345']", }, ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index c00a7b9114..ccdc8c2ecf 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py
@@ -133,13 +133,18 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() + def create_test_resource(self): + return self.hs.get_media_repository_resource() + def test_cache_returns_correct_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -160,10 +165,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check the cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + "GET", "preview_url?url=http://matrix.org", shorthand=False ) - request.render(self.preview_url) - self.pump() # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -178,10 +181,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check the database cache returns the correct response request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + "GET", "preview_url?url=http://matrix.org", shorthand=False ) - request.render(self.preview_url) - self.pump() # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -201,9 +202,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -234,9 +237,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -267,9 +272,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -298,9 +305,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -326,10 +335,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -349,10 +356,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( @@ -368,10 +373,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): Blacklisted IP addresses, accessed directly, are not spidered. """ request, channel = self.make_request( - "GET", "url_preview?url=http://192.168.1.1", shorthand=False + "GET", "preview_url?url=http://192.168.1.1", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -389,10 +392,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): Blacklisted IP ranges, accessed directly, are not spidered. """ request, channel = self.make_request( - "GET", "url_preview?url=http://1.1.1.2", shorthand=False + "GET", "preview_url?url=http://1.1.1.2", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 403) self.assertEqual( @@ -411,9 +412,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -446,10 +449,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): ] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, @@ -468,10 +469,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): ] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -491,10 +490,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( @@ -510,10 +507,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): OPTIONS returns the OPTIONS. """ request, channel = self.make_request( - "OPTIONS", "url_preview?url=http://example.com", shorthand=False + "OPTIONS", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) @@ -525,9 +520,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Build and make a request to the server request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() # Extract Synapse's tcp client @@ -598,10 +595,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", - "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -663,10 +660,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", - "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 2d021f6565..02a46e5fda 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py
@@ -20,15 +20,12 @@ from tests import unittest class HealthCheckTests(unittest.HomeserverTestCase): - def setUp(self): - super().setUp() - + def create_test_resource(self): # replace the JsonResource with a HealthResource. - self.resource = HealthResource() + return HealthResource() def test_health(self): request, channel = self.make_request("GET", "/health", shorthand=False) - self.render(request) self.assertEqual(request.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index dcd65c2a50..6a930f4148 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py
@@ -20,11 +20,9 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): - def setUp(self): - super().setUp() - + def create_test_resource(self): # replace the JsonResource with a WellKnownResource - self.resource = WellKnownResource(self.hs) + return WellKnownResource(self.hs) def test_well_known(self): self.hs.config.public_baseurl = "https://tesths" @@ -33,7 +31,6 @@ class WellKnownTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.render(request) self.assertEqual(request.code, 200) self.assertEqual( @@ -50,6 +47,5 @@ class WellKnownTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.render(request) self.assertEqual(request.code, 404) diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py new file mode 100644
index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py new file mode 100644
index 0000000000..0937f304cf --- /dev/null +++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json + +import synapse.rest.admin +from synapse.config._base import ConfigError +from synapse.rest.client.v1 import login, room +from synapse.rulecheck.domain_rule_checker import DomainRuleChecker + +from tests import unittest +from tests.server import make_request + + +class DomainRuleCheckerTestCase(unittest.TestCase): + def test_allowed(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + "domains_prevented_from_being_invited_to_published_rooms": ["target_two"], + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_one", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_two", "test:target_two", None, "room", False + ) + ) + + # User can invite internal user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test1:target_one", None, "room", False, True + ) + ) + + # User can invite external user to a non-published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, False + ) + ) + + def test_disallowed(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + "source_four": [], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_one", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_one", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_four", "test:target_one", None, "room", False + ) + ) + + # User cannot invite external user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, True + ) + ) + + def test_default_allow(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_default_deny(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_config_parse(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + self.assertEquals(config, DomainRuleChecker.parse_config(config)) + + def test_config_parse_failure(self): + config = { + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + } + } + self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config) + + +class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + hijack_auth = False + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["trusted_third_party_id_servers"] = ["localhost"] + + config["spam_checker"] = { + "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker", + "config": { + "default": True, + "domain_mapping": {}, + "can_only_join_rooms_with_invite": True, + "can_only_create_one_to_one_rooms": True, + "can_only_invite_during_room_creation": True, + "can_invite_by_third_party_id": False, + }, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.admin_user_id = self.register_user("admin_user", "pass", admin=True) + self.admin_access_token = self.login("admin_user", "pass") + + self.normal_user_id = self.register_user("normal_user", "pass", admin=False) + self.normal_access_token = self.login("normal_user", "pass") + + self.other_user_id = self.register_user("other_user", "pass", admin=False) + + def test_admin_can_create_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + def test_normal_user_cannot_create_empty_room(self): + channel = self._create_room(self.normal_access_token) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_cannot_create_room_with_multiple_invites(self): + channel = self._create_room( + self.normal_access_token, + content={"invite": [self.other_user_id, self.admin_user_id]}, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly counts both normal and third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [self.other_user_id], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly rejects third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_can_room_with_single_invites(self): + channel = self._create_room( + self.normal_access_token, content={"invite": [self.other_user_id]} + ) + assert channel.result["code"] == b"200", channel.result + + def test_cannot_join_public_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403 + ) + + def test_can_join_invited_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + def test_cannot_invite(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + def test_cannot_3pid_invite(self): + """Test that unbound 3pid invites get rejected. + """ + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + request, channel = self.make_request( + "POST", + "rooms/%s/invite" % (room_id), + {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"}, + access_token=self.normal_access_token, + ) + self.assertEqual(channel.code, 403, channel.result["body"]) + + def _create_room(self, token, content={}): + path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,) + + request, channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + path, + content=json.dumps(content).encode("utf8"), + ) + + return channel diff --git a/tests/server.py b/tests/server.py
index 3dd2cfc072..a51ad0c14e 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable +from typing import Callable, Iterable, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -19,8 +19,8 @@ from twisted.internet.interfaces import ( ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock -from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.resource import IResource from twisted.web.server import Site from synapse.http.site import SynapseRequest @@ -117,6 +117,25 @@ class FakeChannel: def transport(self): return self + def await_result(self, timeout: int = 100) -> None: + """ + Wait until the request is finished. + """ + self._reactor.run() + x = 0 + + while not self.result.get("done"): + # If there's a producer, tell it to resume producing so we get content + if self._producer: + self._producer.resumeProducing() + + x += 1 + + if x > timeout: + raise TimedOutException("Timed out waiting for request to finish.") + + self._reactor.advance(0.1) + class FakeSite: """ @@ -128,9 +147,21 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") + def __init__(self, resource: IResource): + """ + + Args: + resource: the resource to be used for rendering all requests + """ + self._resource = resource + + def getResourceFor(self, request): + return self._resource + def make_request( reactor, + site: Site, method, path, content=b"", @@ -139,12 +170,19 @@ def make_request( shorthand=True, federation_auth_origin=None, content_is_form=False, + await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): """ - Make a web request using the given method and path, feed it the - content, and return the Request and the Channel underneath. + Make a web request using the given method, path and content, and render it + + Returns the Request and the Channel underneath. Args: + site: The twisted Site to use to render the request + method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). @@ -157,6 +195,12 @@ def make_request( content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + custom_headers: (name, value) pairs to add as request headers + + await_result: whether to wait for the request to complete rendering. If true, + will pump the reactor until the the renderer tells the channel the request + is finished. + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ @@ -178,18 +222,17 @@ def make_request( if not path.startswith(b"/"): path = b"/" + path + if isinstance(content, dict): + content = json.dumps(content).encode("utf8") if isinstance(content, str): content = content.encode("utf8") - site = FakeSite() channel = FakeChannel(site, reactor) req = request(channel) - req.process = lambda: b"" req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(SEEK_END) - req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( @@ -211,35 +254,16 @@ def make_request( # Assume the body is JSON req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") - req.requestReceived(method, path, b"1.1") - - return req, channel - + if custom_headers: + for k, v in custom_headers: + req.requestHeaders.addRawHeader(k, v) -def wait_until_result(clock, request, timeout=100): - """ - Wait until the request is finished. - """ - clock.run() - x = 0 - - while not request.finished: - - # If there's a producer, tell it to resume producing so we get content - if request._channel._producer: - request._channel._producer.resumeProducing() - - x += 1 - - if x > timeout: - raise TimedOutException("Timed out waiting for request to finish.") - - clock.advance(0.1) + req.requestReceived(method, path, b"1.1") + if await_result: + channel.await_result() -def render(request, resource, clock): - request.render(resource) - wait_until_result(clock, request) + return req, channel @implementer(IReactorPluggableNameResolver) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 872039c8f1..e0a9cd93ac 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py
@@ -73,7 +73,6 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) - self.render(request) self.assertEqual(channel.code, 200) # Get the Room ID to join @@ -85,14 +84,12 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/" + room_id + "/join", access_token=self.access_token, ) - self.render(request) self.assertEqual(channel.code, 200) # Sync again, to get the message in the room request, channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) - self.render(request) self.assertEqual(channel.code, 200) # Get the message diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 6382b19dc3..9c8027a5b2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -306,7 +306,6 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): tok = self.login("user", "password") request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) - self.render(request) invites = channel.json_body["rooms"]["invite"] self.assertEqual(len(invites), 0, invites) @@ -320,7 +319,6 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Sync again to retrieve the events in the room, so we can check whether this # room has a notice in it. request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) - self.render(request) # Scan the events in the room to search for a message from the server notices # user. @@ -358,7 +356,6 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): request, channel = self.make_request( "GET", "/sync?timeout=0", access_token=tok, ) - self.render(request) # Also retrieves the list of invites for this user. We don't care about that # one except if we're processing the last user, which should have received an diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 5a1e5c4e66..c13a57dad1 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): ) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) - @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) - def test_send_dummy_event_without_consent(self): - self._create_extremity_rich_graph() - self._enable_consent_checking() - - # Pump the reactor repeatedly so that the background updates have a - # chance to run. Attempt to add dummy event with user that has not consented - # Check that dummy event send fails. - self.pump(10 * 60) - latest_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - ) - self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT) - - # Create new user, and add consent - user2 = self.register_user("user2", "password") - token2 = self.login("user2", "password") - self.get_success( - self.store.user_set_consent_version(user2, self.CONSENT_VERSION) - ) - self.helper.join(self.room_id, user2, tok=token2) - - # Background updates should now cause a dummy event to be added to the graph - self.pump(10 * 60) - - latest_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - ) - self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) - @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) def test_expiry_logic(self): """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index e96ca1c8ca..a69117c5a9 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.server import make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -408,18 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): # Advance to a known time self.reactor.advance(123456 - self.reactor.seconds()) - request, channel = self.make_request( + headers1 = {b"User-Agent": b"Mozzila pizza"} + headers1.update(headers) + + make_request( + self.reactor, + self.site, "GET", - "/_matrix/client/r0/admin/users/" + self.user_id, + "/_synapse/admin/v1/users/" + self.user_id, access_token=access_token, + custom_headers=headers1.items(), **make_request_args, ) - request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") - - # Add the optional headers - for h, v in headers.items(): - request.requestHeaders.addRawHeader(h, v) - self.render(request) # Advance so the save loop occurs self.reactor.advance(100) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..fe37d2ed5a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py
@@ -39,7 +39,7 @@ class DataStoreTestCase(unittest.TestCase): ) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield defer.ensureDeferred( - self.store.set_profile_displayname(self.user.localpart, self.displayname) + self.store.set_profile_displayname(self.user.localpart, self.displayname, 1) ) users, total = yield defer.ensureDeferred( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..7a38022e71 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -36,7 +36,7 @@ class ProfileStoreTestCase(unittest.TestCase): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield defer.ensureDeferred( - self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1) ) self.assertEquals( @@ -54,7 +54,7 @@ class ProfileStoreTestCase(unittest.TestCase): yield defer.ensureDeferred( self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.localpart, "http://my.site/here", 1 ) ) diff --git a/tests/test_mau.py b/tests/test_mau.py
index 654a6fa42d..c5ec6396a7 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py
@@ -202,7 +202,6 @@ class TestMauLimit(unittest.HomeserverTestCase): ) request, channel = self.make_request("POST", "/register", request_data) - self.render(request) if channel.code != 200: raise HttpResponseException( @@ -215,7 +214,6 @@ class TestMauLimit(unittest.HomeserverTestCase): def do_sync_for_user(self, token): request, channel = self.make_request("GET", "/sync", access_token=token) - self.render(request) if channel.code != 200: raise HttpResponseException( diff --git a/tests/test_server.py b/tests/test_server.py
index 655c918a15..c387a85f2e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py
@@ -26,9 +26,9 @@ from synapse.util import Clock from tests import unittest from tests.server import ( + FakeSite, ThreadedMemoryReactorClock, make_request, - render, setup_test_homeserver, ) @@ -62,9 +62,8 @@ class JsonResourceTests(unittest.TestCase): ) request, channel = make_request( - self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" + self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) - render(request, res, self.reactor) self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) @@ -83,8 +82,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -108,8 +106,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -127,8 +124,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -150,8 +146,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") - render(request, res, self.reactor) + _, channel = make_request( + self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar" + ) self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -173,8 +170,7 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. - request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -196,9 +192,6 @@ class OptionsResourceTests(unittest.TestCase): def _make_request(self, method, path): """Create a request from the method/path and return a channel with the response.""" - request, channel = make_request(self.reactor, method, path, shorthand=False) - request.prepath = [] # This doesn't get set properly by make_request. - # Create a site and query for the resource. site = SynapseSite( "test", @@ -207,11 +200,9 @@ class OptionsResourceTests(unittest.TestCase): self.resource, "1.0", ) - request.site = site - resource = site.getResourceFor(request) - # Finally, render the resource and return the channel. - render(request, resource, self.reactor) + # render the request and return the channel + _, channel = make_request(self.reactor, site, method, path, shorthand=False) return channel def test_unknown_options_request(self): @@ -284,8 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -303,8 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -325,8 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -345,8 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"HEAD", b"/path") - render(request, res, self.reactor) + _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) diff --git a/tests/test_state.py b/tests/test_state.py
index 80b0ccbc40..6227a3ba95 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase): "get_state_handler", "get_clock", "get_state_resolution_handler", + "hostname", ] ) hs.config = default_config("tesths", True) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index b89798336c..71580b454d 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py
@@ -54,7 +54,6 @@ class TermsTestCase(unittest.HomeserverTestCase): # Do a UI auth request request_data = json.dumps({"username": "kermit", "password": "monkey"}) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) @@ -98,7 +97,6 @@ class TermsTestCase(unittest.HomeserverTestCase): self.registration_handler.check_username = Mock(return_value=True) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. @@ -116,7 +114,6 @@ class TermsTestCase(unittest.HomeserverTestCase): } ) request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py
@@ -12,9 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import string_types from synapse.api.errors import SynapseError -from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart +from synapse.types import ( + GroupID, + RoomAlias, + UserID, + map_username_to_mxid_localpart, + strip_invalid_mxid_characters, +) from tests import unittest @@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase): self.assertEqual( map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" ) + + +class StripInvalidMxidCharactersTestCase(unittest.TestCase): + def test_return_type(self): + unstripped = strip_invalid_mxid_characters("test") + stripped = strip_invalid_mxid_characters("test@") + + self.assertTrue(isinstance(unstripped, string_types), type(unstripped)) + self.assertTrue(isinstance(stripped, string_types), type(stripped)) + + def test_strip(self): + stripped = strip_invalid_mxid_characters("test@") + self.assertEqual(stripped, "test", stripped) diff --git a/tests/unittest.py b/tests/unittest.py
index e36ac89196..a9d59e31f7 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.trial import unittest +from twisted.web.resource import Resource from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig @@ -47,13 +48,7 @@ from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import ( - FakeChannel, - get_clock, - make_request, - render, - setup_test_homeserver, -) +from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -239,10 +234,8 @@ class HomeserverTestCase(TestCase): if not isinstance(self.hs, HomeServer): raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) - # Register the resources - self.resource = self.create_test_json_resource() - - # create a site to wrap the resource. + # create the root resource, and a site to wrap it. + self.resource = self.create_test_resource() self.site = SynapseSite( logger_name="synapse.access.http.fake", site_tag=self.hs.config.server.server_name, @@ -253,7 +246,7 @@ class HomeserverTestCase(TestCase): from tests.rest.client.v1.utils import RestHelper - self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) + self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) if hasattr(self, "user_id"): if self.hijack_auth: @@ -323,15 +316,12 @@ class HomeserverTestCase(TestCase): hs = self.setup_test_homeserver() return hs - def create_test_json_resource(self): + def create_test_resource(self) -> Resource: """ - Create a test JsonResource, with the relevant servlets registerd to it - - The default implementation calls each function in `servlets` to do the - registration. + Create a the root resource for the test server. - Returns: - JsonResource: + The default implementation creates a JsonResource and calls each function in + `servlets` to register servletes against it """ resource = JsonResource(self.hs) @@ -381,6 +371,7 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, + await_result: bool = True, ) -> Tuple[SynapseRequest, FakeChannel]: ... @@ -395,6 +386,7 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, + await_result: bool = True, ) -> Tuple[T, FakeChannel]: ... @@ -408,6 +400,7 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, + await_result: bool = True, ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the @@ -426,14 +419,16 @@ class HomeserverTestCase(TestCase): content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + await_result: whether to wait for the request to complete rendering. If + true (the default), will pump the test reactor until the the renderer + tells the channel the request is finished. + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ - if isinstance(content, dict): - content = json.dumps(content).encode("utf8") - return make_request( self.reactor, + self.site, method, path, content, @@ -442,18 +437,9 @@ class HomeserverTestCase(TestCase): shorthand, federation_auth_origin, content_is_form, + await_result, ) - def render(self, request): - """ - Render a request against the resources registered by the test class's - servlets. - - Args: - request (synapse.http.site.SynapseRequest): The request to render. - """ - render(request, self.resource, self.reactor) - def setup_test_homeserver(self, *args, **kwargs): """ Set up the test homeserver, meant to be called by the overridable @@ -568,8 +554,7 @@ class HomeserverTestCase(TestCase): self.hs.config.registration_shared_secret = "shared" # Create the user - request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") - self.render(request) + request, channel = self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -595,9 +580,8 @@ class HomeserverTestCase(TestCase): } ) request, channel = self.make_request( - "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") + "POST", "/_synapse/admin/v1/register", body.encode("utf8") ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] @@ -616,7 +600,6 @@ class HomeserverTestCase(TestCase): request, channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] @@ -685,7 +668,6 @@ class HomeserverTestCase(TestCase): request, channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) - self.render(request) self.assertEqual(channel.code, 403, channel.result) def inject_room_member(self, room: str, user: str, membership: Membership) -> None: diff --git a/tests/utils.py b/tests/utils.py
index acec74e9e9..1584eacb12 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -176,6 +176,8 @@ def default_config(name, parse=False): "update_user_directory": False, "caches": {"global_factor": 1}, "listeners": [{"port": 0, "type": "http"}], + # Enable encryption by default in private rooms + "encryption_enabled_by_default_for_room_type": "invite", } if parse: @@ -271,7 +273,7 @@ def setup_test_homeserver( # Install @cache_in_self attributes for key, val in kwargs.items(): - setattr(hs, key, val) + setattr(hs, "_" + key, val) # Mock TLS hs.tls_server_context_factory = Mock()