diff options
Diffstat (limited to 'tests')
71 files changed, 2172 insertions, 1238 deletions
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 43fef5d64a..e0ca288829 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -57,7 +57,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") + channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 400 + unrecognised, because nothing is registered self.assertEqual(channel.code, 400) @@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - _, channel = make_request(self.reactor, site, "PUT", "presence/a/status") + channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 401, because the stub servlet still checks authentication self.assertEqual(channel.code, 401) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index b260ab734d..467033e201 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -73,7 +73,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): return raise - _, channel = make_request( + channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) @@ -121,7 +121,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): return raise - _, channel = make_request( + channel = make_request( self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d146f2254f..1d65ea2f9c 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock() kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Tests that we correctly handle key requests for keys we've stored with a null `ts_valid_until_ms` """ - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( @@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) - mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 0187f56e21..9ccd2d76b8 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -48,7 +48,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Get the room complexity - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) self.assertEquals(200, channel.code) @@ -60,7 +60,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) self.assertEquals(200, channel.code) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 3009fbb6c4..cfeccc0577 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -46,7 +46,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/get_missing_events/(?P<room_id>[^/]*)/?" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, @@ -95,7 +95,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) self.inject_room_member(room_1, "@user:other.example.com", "join") - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) self.assertEquals(200, channel.code, channel.result) @@ -127,7 +127,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/federation/transport/__init__.py diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index f9e3c7a51f..85500e169c 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,38 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter - from tests import unittest from tests.unittest import override_config -class RoomDirectoryFederationTests(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - class Authenticator: - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - +class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"allow_public_rooms_over_federation": False}) def test_blocked_public_room_list_over_federation(self): - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/publicRooms" + """Test that unauthenticated requests to the public rooms directory 403 when + allow_public_rooms_over_federation is False. + """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", ) self.assertEquals(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/publicRooms" + """Test that unauthenticated requests to the public rooms directory 200 when + allow_public_rooms_over_federation is True. + """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", ) self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py new file mode 100644 index 0000000000..bd7a1b6891 --- /dev/null +++ b/tests/handlers/test_cas.py @@ -0,0 +1,121 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.handlers.cas_handler import CasResponse + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" +SERVER_URL = "https://issuer/" + + +class CasHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + cas_config = { + "enabled": True, + "server_url": SERVER_URL, + "service_url": BASE_URL, + } + config["cas_config"] = cas_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_cas_handler() + + # Reduce the number of attempts when generating MXIDs. + sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + def test_map_cas_user_to_user(self): + """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_existing_user(self): + """Existing users can log in with CAS account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # Map a user via SSO. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + # Subsequent calls should map to the same mxid. + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_invalid_localpart(self): + """CAS automaps invalid characters to base-64 encoding.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("föö", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + ) + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 770d225ed5..a39f898608 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -405,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): def test_denied(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), @@ -415,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), @@ -431,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.assertEquals(200, channel.code, channel.result) @@ -446,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request("GET", b"publicRooms") + channel = self.make_request("GET", b"publicRooms") self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -454,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = False # Room list disabled so we should get no results - request, channel = self.make_request("GET", b"publicRooms") + channel = self.make_request("GET", b"publicRooms") self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index af42775815..f955dfa490 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): # Redaction of event should fail. path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id) - request, channel = self.make_request( + channel = self.make_request( "POST", path, content={}, access_token=self.access_token ) self.assertEqual(int(channel.result["code"]), 403) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 9878527bab..368d600b33 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,17 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from urllib.parse import parse_qs, urlparse +import re +from typing import Dict +from urllib.parse import parse_qs, urlencode, urlparse from mock import ANY, Mock, patch import pymacaroons -from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider +from twisted.web.resource import Resource + +from synapse.api.errors import RedirectException +from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException +from synapse.rest.client.v1 import login +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.server import HomeServer from synapse.types import UserID -from tests.test_utils import FakeResponse +from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -55,11 +63,14 @@ COOKIE_NAME = b"oidc_session" COOKIE_PATH = "/_synapse/oidc" -class TestMappingProvider(OidcMappingProvider): +class TestMappingProvider: @staticmethod def parse_config(config): return + def __init__(self, config): + pass + def get_remote_user_id(self, userinfo): return userinfo["sub"] @@ -82,16 +93,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -def simple_async_mock(return_value=None, raises=None) -> Mock: - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - async def get_json(url): # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: @@ -370,6 +371,13 @@ class OidcHandlerTestCase(HomeserverTestCase): - when the userinfo fetching fails - when the code exchange fails """ + + # ensure that we are correctly testing the fallback when "get_extra_attributes" + # is not implemented. + mapping_provider = self.handler._user_mapping_provider + with self.assertRaises(AttributeError): + _ = mapping_provider.get_extra_attributes + token = { "type": "bearer", "id_token": "id_token", @@ -399,14 +407,14 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request = self._build_callback_request( + request = _build_callback_request( code, state, session, user_agent=user_agent, ip_address=ip_address ) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, {}, + expected_user_id, request, client_redirect_url, None, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -437,7 +445,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, {}, + expected_user_id, request, client_redirect_url, None, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() @@ -607,7 +615,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request = self._build_callback_request("code", state, session) + request = _build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) @@ -624,9 +632,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test_user", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, {} + "@test_user:test", ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -635,9 +643,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": 1234, "username": "test_user_2", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, {} + "@test_user_2:test", ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -648,7 +656,7 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", @@ -672,16 +680,16 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, {}, + user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, {}, + user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -694,9 +702,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, {}, + user.to_string(), ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -715,7 +723,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_not_called() args = self.assertRenderedError("mapping_error") self.assertTrue( @@ -730,14 +738,16 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, {}, + "@TEST_USER_2:test", ANY, ANY, None, ) def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" - self._make_callback_with_userinfo({"sub": "test2", "username": "föö"}) + self.get_success( + _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) + ) self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( @@ -762,11 +772,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, {}, + "@test_user1:test", ANY, ANY, None, ) auth_handler.complete_sso_login.reset_mock() @@ -784,70 +794,214 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - self._make_callback_with_userinfo(userinfo) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_not_called() self.assertRenderedError( "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) - def _make_callback_with_userinfo( - self, userinfo: dict, client_redirect_url: str = "http://client/redirect" - ) -> None: - self.handler._exchange_code = simple_async_mock(return_value={}) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + def test_empty_localpart(self): + """Attempts to map onto an empty localpart should be rejected.""" + userinfo = { + "sub": "tester", + "username": "", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") - state = "state" - session = self.handler._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, - ) - request = self._build_callback_request("code", state, session) + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.username }}"} + } + } + } + ) + def test_null_localpart(self): + """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" + userinfo = { + "sub": "tester", + "username": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") - self.get_success(self.handler.handle_oidc_callback(request)) - def _build_callback_request( - self, - code: str, - state: str, - session: str, - user_agent: str = "Browser", - ip_address: str = "10.0.0.1", - ): - """Builds a fake SynapseRequest to mock the browser callback - - Returns a Mock object which looks like the SynapseRequest we get from a browser - after SSO (before we return to the client) - - Args: - code: the authorization code which would have been returned by the OIDC - provider - state: the "state" param which would have been passed around in the - query param. Should be the same as was embedded in the session in - _build_oidc_session. - session: the "session" which would have been passed around in the cookie. - user_agent: the user-agent to present - ip_address: the IP address to pretend the request came from - """ - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) +class UsernamePickerTestCase(HomeserverTestCase): + servlets = [login.register_servlets] - request.getCookie.return_value = session - request.args = {} - request.args[b"code"] = [code.encode("utf-8")] - request.args[b"state"] = [state.encode("utf-8")] - request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent - return request + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": { + "config": {"display_name_template": "{{ user.displayname }}"} + }, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) + config["oidc_config"] = oidc_config + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # first of all, mock up an OIDC callback to the OidcHandler, which should + # raise a RedirectException + userinfo = {"sub": "tester", "displayname": "Jonny"} + f = self.get_failure( + _make_callback_with_userinfo( + self.hs, userinfo, client_redirect_url=client_redirect_url + ), + RedirectException, + ) + + # check the Location and cookies returned by the RedirectException + self.assertEqual(f.value.location, b"/_synapse/client/pick_username") + cookieheader = f.value.cookies[0] + regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") + m = regex.search(cookieheader) + if not m: + self.fail("cookie header %s does not match %s" % (cookieheader, regex)) + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + session_id = m.group(1).decode("ascii") + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map" + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, client_redirect_url) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = f.value.location + b"/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", cookieheader), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location starts with the requested redirect URL + self.assertEqual( + location_headers[0][: len(client_redirect_url)], client_redirect_url + ) + + # fish the login token out of the returned redirect uri + parts = urlparse(location_headers[0]) + query = parse_qs(parts.query) + login_token = query["loginToken"][0] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") + + +async def _make_callback_with_userinfo( + hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" +) -> None: + """Mock up an OIDC callback with the given userinfo dict + + We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, + and poke in the userinfo dict as if it were the response to an OIDC userinfo call. + + Args: + hs: the HomeServer impl to send the callback to. + userinfo: the OIDC userinfo dict + client_redirect_url: the URL to redirect to on success. + """ + handler = hs.get_oidc_handler() + handler._exchange_code = simple_async_mock(return_value={}) + handler._parse_id_token = simple_async_mock(return_value=userinfo) + handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + + state = "state" + session = handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + request = _build_callback_request("code", state, session) + + await handler.handle_oidc_callback(request) + + +def _build_callback_request( + code: str, + state: str, + session: str, + user_agent: str = "Browser", + ip_address: str = "10.0.0.1", +): + """Builds a fake SynapseRequest to mock the browser callback + + Returns a Mock object which looks like the SynapseRequest we get from a browser + after SSO (before we return to the client) + + Args: + code: the authorization code which would have been returned by the OIDC + provider + state: the "state" param which would have been passed around in the + query param. Should be the same as was embedded in the session in + _build_oidc_session. + session: the "session" which would have been passed around in the cookie. + user_agent: the user-agent to present + ip_address: the IP address to pretend the request came from + """ + request = Mock( + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] + ) + + request.getCookie.return_value = session + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent + return request diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 8d50265145..f816594ee4 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -551,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) def _get_login_flows(self) -> JsonDict: - _, channel = self.make_request("GET", "/_matrix/client/r0/login") + channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) return channel.json_body["flows"] @@ -560,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def _send_login(self, type, user, **params) -> FakeChannel: params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) - _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel def _start_delete_device_session(self, access_token, device_id) -> str: @@ -597,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", ) -> FakeChannel: """Delete an individual device.""" - _, channel = self.make_request( + channel = self.make_request( "DELETE", "devices/" + device, body, access_token=access_token ) return channel diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index d21e5588ca..548038214b 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from mock import Mock + import attr from synapse.api.errors import RedirectException -from synapse.handlers.sso import MappingException +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -44,6 +48,8 @@ BASE_URL = "https://synapse/" @attr.s class FakeAuthnResponse: ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) class TestMappingProvider: @@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase): def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - # The redirect_url doesn't matter with the default user mapping provider. - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None ) - self.assertEqual(mxid, "@test_user:test") @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): @@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase): store.register_user(user_id="@test_user:test", password_hash=None) ) + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. saml_response = FakeAuthnResponse( {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} ) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "", None ) - self.assertEqual(mxid, "@test_user:test") # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "", None ) - self.assertEqual(mxid, "@test_user:test") def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - redirect_url = "" - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + auth_handler.complete_sso_login.assert_not_called() def test_map_saml_response_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) + + # send the fake SAML response saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "", None + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular SAML username. self.get_success( @@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase): # Now attempt to map to a username, this will fail since all potential usernames are taken. saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", ) + auth_handler.complete_sso_login.assert_not_called() @override_config( { @@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase): } ) def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" + request = _mock_request() e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), + self.handler._handle_authn_response(request, saml_response, ""), RedirectException, ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index f21de958f1..96e5bdac4a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -220,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - (request, channel) = self.make_request( + channel = self.make_request( "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 647a17cb90..9c886d671a 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) ) + regular_user_id = "@regular:test" + self.get_success( + self.store.register_user(user_id=regular_user_id, password_hash=None) + ) self.get_success( self.handler.handle_local_profile_change(support_user_id, None) @@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): display_name = "display_name" profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) - regular_user_id = "@regular:test" self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) self.assertTrue(profile["display_name"] == display_name) + def test_handle_local_profile_change_with_deactivated_user(self): + # create user + r_user_id = "@regular:test" + self.get_success( + self.store.register_user(user_id=r_user_id, password_hash=None) + ) + + # update profile + display_name = "Regular User" + profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) + self.get_success( + self.handler.handle_local_profile_change(r_user_id, profile_info) + ) + + # profile is in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile["display_name"] == display_name) + + # deactivate user + self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) + self.get_success(self.handler.handle_user_deactivated(r_user_id)) + + # profile is not in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile is None) + + # update profile after deactivation + self.get_success( + self.handler.handle_local_profile_change(r_user_id, profile_info) + ) + + # profile is furthermore not in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile is None) + def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" self.get_success( @@ -534,7 +572,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.helper.join(room, user=u2) # Assert user directory is not empty - request, channel = self.make_request( + channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.assertEquals(200, channel.code, channel.result) @@ -542,7 +580,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False - request, channel = self.make_request( + channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.assertEquals(200, channel.code, channel.result) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 626acdcaa3..4e51839d0f 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -36,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( + WELL_KNOWN_MAX_SIZE, WellKnownResolver, _cache_period_from_headers, ) @@ -1107,6 +1108,32 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_well_known_too_large(self): + """A well-known query that returns a result which is too large should be rejected.""" + self.reactor.lookups["testserv"] = "1.2.3.4" + + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + response_headers={b"Cache-Control": b"max-age=1000"}, + content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', + ) + + # The result is sucessful, but disabled delegation. + r = self.successResultOf(fetch_d) + self.assertIsNone(r.delegated_server) + def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 05e9c449be..453391a5a5 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -46,16 +46,16 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) def test_sync(self): handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request(self.reactor, FakeSite(resource), "GET", "/") - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index bcdcafa5a9..961bf09de9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -209,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump(10) @@ -220,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One email was attempted to be sent self.assertEqual(len(self.email_attempts), 1) @@ -238,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8b4af74c51..60f0820cff 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -144,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump() @@ -155,7 +155,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One push was attempted to be sent -- it'll be the first message self.assertEqual(len(self.push_attempts), 1) @@ -176,8 +176,8 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) - last_stream_ordering = pushers[0]["last_stream_ordering"] + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) + last_stream_ordering = pushers[0].last_stream_ordering # Now it'll try and send the second push message, which will be the second one self.assertEqual(len(self.push_attempts), 2) @@ -198,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) def test_sends_high_priority_for_encrypted(self): """ @@ -667,7 +667,7 @@ class HTTPPusherTests(HomeserverTestCase): # This will actually trigger a new notification to be sent out so that # even if the user does not receive another message, their unread # count goes down - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), {}, diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index fe9e4d5f9a..f35a5235e1 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Tuple -from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -47,7 +45,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): return config - def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]: + def _test_register(self) -> FakeChannel: """Run the actual test: 1. Create a worker homeserver. @@ -59,14 +57,14 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): worker_hs = self.make_worker_hs("synapse.app.client_reader") site = self._hs_to_site[worker_hs] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] @@ -83,8 +81,8 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def test_no_auth(self): """With no authentication the request should finish. """ - request, channel = self._test_register() - self.assertEqual(request.code, 200) + channel = self._test_register() + self.assertEqual(channel.code, 200) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") @@ -93,8 +91,8 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def test_missing_auth(self): """If the main process expects a secret that is not provided, an error results. """ - request, channel = self._test_register() - self.assertEqual(request.code, 500) + channel = self._test_register() + self.assertEqual(channel.code, 500) @override_config( { @@ -105,15 +103,15 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): def test_unauthorized(self): """If the main process receives the wrong secret, an error results. """ - request, channel = self._test_register() - self.assertEqual(request.code, 500) + channel = self._test_register() + self.assertEqual(channel.code, 500) @override_config({"worker_replication_secret": "my-secret"}) def test_authorized(self): """The request should finish when the worker provides the authentication header. """ - request, channel = self._test_register() - self.assertEqual(request.code, 200) + channel = self._test_register() + self.assertEqual(channel.code, 200) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index fdaad3d8ad..4608b65a0c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -14,11 +14,10 @@ # limitations under the License. import logging -from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.server import FakeChannel, make_request +from tests.server import make_request logger = logging.getLogger(__name__) @@ -41,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs = self.make_worker_hs("synapse.app.client_reader") site = self._hs_to_site[worker_hs] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = make_request( + channel_2 = make_request( self.reactor, site, "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_2.code, 200) + ) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") @@ -73,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") site_1 = self._hs_to_site[worker_hs_1] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site_1, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth site_2 = self._hs_to_site[worker_hs_2] - request_2, channel_2 = make_request( + channel_2 = make_request( self.reactor, site_2, "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_2.code, 200) + ) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 83afd9fd2f..d1feca961f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): the media which the caller should respond to. """ resource = hs.get_media_repository_resource().children[b"download"] - _, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "GET", diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 77fc3856d5..8d494ebc03 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Do an initial sync so that we're up to date. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token ) next_batch = channel.json_body["next_batch"] @@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Check that syncing still gets the new event, despite the gap in the # stream IDs. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) first_event_in_room2 = response["event_id"] - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back in the first room should not produce any results, as # no events have happened in it. This tests that we are correctly # filtering results based on the vector clock portion. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back on the second room should produce the first event # again. This tests that pagination isn't completely broken. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Paginating forwards should give the same results - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) self.assertListEqual([], channel.json_body["chunk"]) - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 67d8878395..0504cd187e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase): return resource def test_version_string(self): - request, channel = self.make_request("GET", self.url, shorthand=False) + channel = self.make_request("GET", self.url, shorthand=False) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -68,7 +68,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def test_delete_group(self): # Create a new group - request, channel = self.make_request( + channel = self.make_request( "POST", "/create_group".encode("ascii"), access_token=self.admin_user_tok, @@ -84,13 +84,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Invite/join another user url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -101,7 +101,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Now delete the group url = "/_synapse/admin/v1/delete_group/" + group_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=self.admin_user_tok, @@ -123,7 +123,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): """ url = "/groups/%s/profile" % (group_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) @@ -134,7 +134,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/joined_groups".encode("ascii"), access_token=access_token ) @@ -216,7 +216,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): def _ensure_quarantined(self, admin_user_tok, server_and_media_id): """Ensure a piece of media is quarantined when trying to access it.""" - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", @@ -241,7 +241,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt quarantine media APIs as non-admin url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) @@ -254,7 +254,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # And the roomID/userID endpoint url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) @@ -282,7 +282,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", @@ -299,7 +299,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): urllib.parse.quote(server_name), urllib.parse.quote(media_id), ) - request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request("POST", url, access_token=admin_user_tok,) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -351,7 +351,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( room_id ) - request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request("POST", url, access_token=admin_user_tok,) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( @@ -395,7 +395,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( non_admin_user ) - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) self.pump(1.0) @@ -437,7 +437,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( non_admin_user ) - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) self.pump(1.0) @@ -453,7 +453,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) # Attempt to access each piece of media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index cf3a007598..248c4442c3 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -50,17 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Try to get a device of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("PUT", self.url, b"{}") + channel = self.make_request("PUT", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("DELETE", self.url, b"{}") + channel = self.make_request("DELETE", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -69,21 +69,21 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url, access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( + channel = self.make_request( "DELETE", self.url, access_token=self.other_user_token, ) @@ -99,23 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -129,23 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -158,22 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) # Delete unknown device returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -197,7 +179,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): } body = json.dumps(update) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, @@ -208,9 +190,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -227,16 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( - "PUT", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -247,7 +223,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ # Set new display_name body = json.dumps({"display_name": "new displayname"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, @@ -257,9 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) @@ -268,9 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Tests that a normal lookup for a device is successfully """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -291,7 +263,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(1, number_devices) # Delete device - request, channel = self.make_request( + channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok, ) @@ -323,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ Try to list devices of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -334,9 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -346,9 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -359,9 +327,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -373,9 +339,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ # Get devices - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -391,9 +355,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.login("user", "pass") # Get devices - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) @@ -431,7 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ Try to delete devices of an user without authentication. """ - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -442,9 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "POST", self.url, access_token=other_user_token, - ) + channel = self.make_request("POST", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -454,9 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" - request, channel = self.make_request( - "POST", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("POST", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -467,9 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" - request, channel = self.make_request( - "POST", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("POST", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -479,7 +435,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a remove of a device that does not exist returns 200. """ body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, @@ -510,7 +466,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Delete devices body = json.dumps({"devices": device_ids}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 11b72c10f7..aa389df12f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -74,7 +74,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ Try to get an event report without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -84,9 +84,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. """ - request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -96,9 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -111,7 +107,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with limit """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) @@ -126,7 +122,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a defined starting point (from) """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) @@ -141,7 +137,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a defined starting point and limit """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) @@ -156,7 +152,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of room """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?room_id=%s" % self.room_id1, access_token=self.admin_user_tok, @@ -176,7 +172,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of user """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?user_id=%s" % self.other_user, access_token=self.admin_user_tok, @@ -196,7 +192,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of user and room """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), access_token=self.admin_user_tok, @@ -218,7 +214,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ # fetch the most recent first, largest timestamp - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=b", access_token=self.admin_user_tok, ) @@ -234,7 +230,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report += 1 # fetch the oldest first, smallest timestamp - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=f", access_token=self.admin_user_tok, ) @@ -254,7 +250,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a invalid search order returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) @@ -267,7 +263,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a negative limit parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) @@ -279,7 +275,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a negative from parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) @@ -293,7 +289,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) @@ -304,7 +300,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) @@ -315,7 +311,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) @@ -327,7 +323,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # Check # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) @@ -342,7 +338,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), json.dumps({"score": -100, "reason": "this makes me sad"}), @@ -399,7 +395,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ Try to get event report without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -409,9 +405,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. """ - request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -421,9 +415,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): Testing get a reported event """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) @@ -434,7 +426,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ # `report_id` is negative - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/-123", access_token=self.admin_user_tok, @@ -448,7 +440,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) # `report_id` is a non-numerical string - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/abcdef", access_token=self.admin_user_tok, @@ -462,7 +454,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) # `report_id` is undefined - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/", access_token=self.admin_user_tok, @@ -480,7 +472,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): Testing that a not existing `report_id` returns a 404. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v1/event_reports/123", access_token=self.admin_user_tok, @@ -496,7 +488,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), json.dumps({"score": -100, "reason": "this makes me sad"}), diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index dadf9db660..c2b998cdae 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -50,7 +50,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - request, channel = self.make_request("DELETE", url, b"{}") + channel = self.make_request("DELETE", url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -64,9 +64,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - request, channel = self.make_request( - "DELETE", url, access_token=self.other_user_token, - ) + channel = self.make_request("DELETE", url, access_token=self.other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -77,9 +75,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -90,9 +86,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) @@ -121,7 +115,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(server_name, self.server_name) # Attempt to access media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(download_resource), "GET", @@ -146,9 +140,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) # Delete media - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -157,7 +149,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) # Attempt to access media - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(download_resource), "GET", @@ -204,7 +196,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): Try to delete media without authentication. """ - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -216,7 +208,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.other_user = self.register_user("user", "pass") self.other_user_token = self.login("user", "pass") - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.other_user_token, ) @@ -229,7 +221,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" - request, channel = self.make_request( + channel = self.make_request( "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, ) @@ -240,9 +232,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ If the parameter `before_ts` is missing, an error is returned. """ - request, channel = self.make_request( - "POST", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -254,7 +244,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ If parameters are invalid, an error is returned. """ - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, ) @@ -265,7 +255,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=1234&size_gt=-1234", access_token=self.admin_user_tok, @@ -278,7 +268,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=1234&keep_profiles=not_bool", access_token=self.admin_user_tok, @@ -308,7 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # timestamp after upload/create now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, @@ -332,7 +322,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, @@ -344,7 +334,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # timestamp after upload now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, @@ -367,7 +357,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, @@ -378,7 +368,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, @@ -401,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) # set media as avatar - request, channel = self.make_request( + channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.admin_user,), content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), @@ -410,7 +400,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, @@ -421,7 +411,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, @@ -445,7 +435,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # set media as room avatar room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.avatar" % (room_id,), content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), @@ -454,7 +444,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, @@ -465,7 +455,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id) now_ms = self.clock.time_msec() - request, channel = self.make_request( + channel = self.make_request( "POST", self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, @@ -512,7 +502,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): media_id = server_and_media_id.split("/")[1] local_path = self.filepaths.local_media_filepath(media_id) - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(download_resource), "GET", diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 9c100050d2..fa620f97f3 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -20,6 +20,7 @@ from typing import List, Optional from mock import Mock import synapse.rest.admin +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.rest.client.v1 import directory, events, login, room @@ -79,7 +80,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/shutdown_room/" + room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -103,7 +104,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Enable world readable url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), @@ -113,7 +114,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/shutdown_room/" + room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -130,7 +131,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): """ url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -138,7 +139,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): ) url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -184,7 +185,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. """ - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, json.dumps({}), access_token=self.other_user_tok, ) @@ -197,7 +198,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms/!unknown:test/delete" - request, channel = self.make_request( + channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) @@ -210,7 +211,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms/invalidroom/delete" - request, channel = self.make_request( + channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) @@ -225,7 +226,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"new_room_user_id": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -244,7 +245,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"new_room_user_id": "@not:exist.bla"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -262,7 +263,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"block": "NotBool"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -278,7 +279,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"purge": "NotBool"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -304,7 +305,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": True, "purge": True}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), @@ -337,7 +338,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": True}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), @@ -371,7 +372,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": False}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), @@ -418,7 +419,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -448,7 +449,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Enable world readable url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), @@ -465,7 +466,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), @@ -530,7 +531,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -538,7 +539,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.assertEqual( @@ -569,7 +570,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), {"room_id": room_id}, @@ -623,7 +624,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) @@ -704,7 +705,7 @@ class RoomTestCase(unittest.HomeserverTestCase): limit, "name", ) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual( @@ -744,7 +745,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_ids, returned_room_ids) url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -764,7 +765,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), {"room_id": room_id}, @@ -794,7 +795,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -835,7 +836,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/r0/directory/room/%s" % ( urllib.parse.quote(test_alias), ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), {"room_id": room_id}, @@ -875,7 +876,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) if reverse: url += "&dir=b" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1011,7 +1012,7 @@ class RoomTestCase(unittest.HomeserverTestCase): expected_http_code: The expected http code for the request """ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -1050,6 +1051,13 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(room_id_2, "else") _search_test(room_id_2, "se") + # Test case insensitive + _search_test(room_id_1, "SOMETHING") + _search_test(room_id_1, "THING") + + _search_test(room_id_2, "ELSE") + _search_test(room_id_2, "SE") + _search_test(None, "foo") _search_test(None, "bar") _search_test(None, "", expected_http_code=400) @@ -1072,7 +1080,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1102,7 +1110,7 @@ class RoomTestCase(unittest.HomeserverTestCase): room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1114,7 +1122,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.join(room_id_1, user_1, tok=user_tok_1) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1124,7 +1132,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok) self.helper.leave(room_id_1, user_1, tok=user_tok_1) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1153,7 +1161,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.join(room_id_2, user_3, tok=user_tok_3) url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1164,7 +1172,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 3) url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1204,7 +1212,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1220,7 +1228,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"unknown_parameter": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1236,7 +1244,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1252,7 +1260,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": "@not:exist.bla"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1272,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1289,7 +1297,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1308,7 +1316,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), @@ -1320,7 +1328,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1337,7 +1345,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1367,7 +1375,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if server admin is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1378,7 +1386,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1389,7 +1397,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1406,7 +1414,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), @@ -1418,13 +1426,150 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) +class MakeRoomAdminTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format( + self.public_room_id + ) + + def test_public_room(self): + """Test that getting admin in a public room works. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.change_membership( + room_id, + self.admin_user, + "@test:test", + Membership.BAN, + tok=self.admin_user_tok, + ) + + def test_private_room(self): + """Test that getting admin in a private room works and we get invited. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False, + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room (we should have received an + # invite) and can ban a user. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.change_membership( + room_id, + self.admin_user, + "@test:test", + Membership.BAN, + tok=self.admin_user_tok, + ) + + def test_other_user(self): + """Test that giving admin in a public room works to a non-admin user works. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={"user_id": self.second_user_id}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.second_user_id, tok=self.second_tok) + self.helper.change_membership( + room_id, + self.second_user_id, + "@test:test", + Membership.BAN, + tok=self.second_tok, + ) + + def test_not_enough_power(self): + """Test that we get a sensible error if there are no local room admins. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + # The creator drops admin rights in the room. + pl = self.helper.get_state( + room_id, EventTypes.PowerLevels, tok=self.creator_tok + ) + pl["users"][self.creator] = 0 + self.helper.send_state( + room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + # We expect this to fail with a 400 as there are no room admins. + # + # (Note we assert the error message to ensure that it's not denied for + # some other reason) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + channel.json_body["error"], + "No local admin user in room with power to update power levels.", + ) + + PURGE_TABLES = [ "current_state_events", "event_backward_extremities", @@ -1453,7 +1598,6 @@ PURGE_TABLES = [ "event_push_summary", "pusher_throttle", "group_summary_rooms", - "local_invites", "room_account_data", "room_tags", # "state_groups", # Current impl leaves orphaned state groups around. diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 907b49f889..73f8a8ec99 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -46,7 +46,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ Try to list users without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -55,7 +55,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error 403 is returned. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url, json.dumps({}), access_token=self.other_user_tok, ) @@ -67,7 +67,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): If parameters are invalid, an error is returned. """ # unkown order_by - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok, ) @@ -75,7 +75,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) @@ -83,7 +83,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) @@ -91,7 +91,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok, ) @@ -99,7 +99,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok, ) @@ -107,7 +107,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=10&until_ts=5", access_token=self.admin_user_tok, @@ -117,7 +117,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=", access_token=self.admin_user_tok, ) @@ -125,7 +125,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) @@ -138,7 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ self._create_users_with_media(10, 2) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) @@ -154,7 +154,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ self._create_users_with_media(20, 2) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) @@ -170,7 +170,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ self._create_users_with_media(20, 2) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) @@ -190,7 +190,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) @@ -201,7 +201,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) @@ -212,7 +212,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) @@ -223,7 +223,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # Set `from` to value of `next_token` for request remaining entries # Check `next_token` does not appear - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) @@ -238,9 +238,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): if users have no media created """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -316,15 +314,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ts1 = self.clock.time_msec() # list all media when filter is not set - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media # result is 0 - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -337,7 +333,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_media(self.other_user_tok, 3) # filter media between `ts1` and `ts2` - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, @@ -346,7 +342,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -356,14 +352,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_users_with_media(20, 1) # check without filter get all users - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, @@ -372,7 +366,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, @@ -382,7 +376,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 1) # filter and get empty result - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -447,7 +441,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url = self.url + "?order_by=%s" % (order_type,) if dir is not None and dir in ("b", "f"): url += "&dir=%s" % (dir,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ba1438cdc7..9b2e4765f6 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -18,6 +18,7 @@ import hmac import json import urllib.parse from binascii import unhexlify +from typing import Optional from mock import Mock @@ -70,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -87,7 +88,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.get_secrets = Mock(return_value=secrets) - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) self.assertEqual(channel.json_body, {"nonce": "abcd"}) @@ -96,14 +97,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): Calling GET on the endpoint will return a randomised nonce, which will only last for SALT_TIMEOUT (60s). """ - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] # 59 seconds self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -111,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 61 seconds self.reactor.advance(2) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -120,7 +121,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ Only the provided nonce can be used, as it's checked in the MAC. """ - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -136,7 +137,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("HMAC incorrect", channel.json_body["error"]) @@ -146,7 +147,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): When the correct nonce is provided, and the right key is provided, the user is registered. """ - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -165,7 +166,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -174,7 +175,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ A valid unrecognised nonce. """ - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -190,13 +191,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -209,7 +210,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ def nonce(): - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) return channel.json_body["nonce"] # @@ -218,7 +219,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -229,28 +230,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -261,28 +262,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -300,7 +301,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "user_type": "invalid", } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid user type", channel.json_body["error"]) @@ -311,7 +312,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ # set no displayname - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -321,17 +322,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps( {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob1:test", channel.json_body["user_id"]) - request, channel = self.make_request("GET", "/profile/@bob1:test/displayname") + channel = self.make_request("GET", "/profile/@bob1:test/displayname") self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -347,17 +348,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob2:test", channel.json_body["user_id"]) - request, channel = self.make_request("GET", "/profile/@bob2:test/displayname") + channel = self.make_request("GET", "/profile/@bob2:test/displayname") self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -373,16 +374,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob3:test", channel.json_body["user_id"]) - request, channel = self.make_request("GET", "/profile/@bob3:test/displayname") + channel = self.make_request("GET", "/profile/@bob3:test/displayname") self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) # set displayname - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -398,12 +399,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob4:test", channel.json_body["user_id"]) - request, channel = self.make_request("GET", "/profile/@bob4:test/displayname") + channel = self.make_request("GET", "/profile/@bob4:test/displayname") self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @@ -429,7 +430,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) # Register new user with admin API - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -448,7 +449,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -466,23 +467,38 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.register_user("user1", "pass1", admin=False) - self.register_user("user2", "pass2", admin=False) + self.user1 = self.register_user( + "user1", "pass1", admin=False, displayname="Name 1" + ) + self.user2 = self.register_user( + "user2", "pass2", admin=False, displayname="Name 2" + ) def test_no_auth(self): """ Try to list users without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user1", "pass1") + + channel = self.make_request("GET", self.url, access_token=other_user_token) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): """ List all users, including deactivated users. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?deactivated=true", b"{}", @@ -493,6 +509,83 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) + # Check that all fields are available + for u in channel.json_body["users"]: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def test_search_term(self): + """Test that searching for a users works correctly""" + + def _search_test( + expected_user_id: Optional[str], + search_term: str, + search_field: Optional[str] = "name", + expected_http_code: Optional[int] = 200, + ): + """Search for a user and check that the returned user's id is a match + + Args: + expected_user_id: The user_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for user names with + search_field: Field which is to request: `name` or `user_id` + expected_http_code: The expected http code for the request + """ + url = self.url + "?%s=%s" % (search_field, search_term,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that users were returned + self.assertTrue("users" in channel.json_body) + users = channel.json_body["users"] + + # Check that the expected number of users were returned + expected_user_count = 1 if expected_user_id else 0 + self.assertEqual(len(users), expected_user_count) + self.assertEqual(channel.json_body["total"], expected_user_count) + + if expected_user_id: + # Check that the first returned user id is correct + u = users[0] + self.assertEqual(expected_user_id, u["name"]) + + # Perform search tests + _search_test(self.user1, "er1") + _search_test(self.user1, "me 1") + + _search_test(self.user2, "er2") + _search_test(self.user2, "me 2") + + _search_test(self.user1, "er1", "user_id") + _search_test(self.user2, "er2", "user_id") + + # Test case insensitive + _search_test(self.user1, "ER1") + _search_test(self.user1, "NAME 1") + + _search_test(self.user2, "ER2") + _search_test(self.user2, "NAME 2") + + _search_test(self.user1, "ER1", "user_id") + _search_test(self.user2, "ER2", "user_id") + + _search_test(None, "foo") + _search_test(None, "bar") + + _search_test(None, "foo", "user_id") + _search_test(None, "bar", "user_id") + class UserRestTestCase(unittest.HomeserverTestCase): @@ -508,7 +601,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.other_user = self.register_user("user", "pass") + self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.login("user", "pass") self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user @@ -520,14 +613,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@bob:test" - request, channel = self.make_request( - "GET", url, access_token=self.other_user_token, - ) + channel = self.make_request("GET", url, access_token=self.other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.other_user_token, content=b"{}", ) @@ -539,7 +630,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v2/users/@unknown_person:test", access_token=self.admin_user_tok, @@ -565,7 +656,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -581,9 +672,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -612,7 +701,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -628,9 +717,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -656,9 +743,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Sync to set admin user to active # before limit of monthly active users is reached - request, channel = self.make_request( - "GET", "/sync", access_token=self.admin_user_tok - ) + channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) if channel.code != 200: raise HttpResponseException( @@ -681,7 +766,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123", "admin": False}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -720,7 +805,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123", "admin": False}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -757,7 +842,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -774,7 +859,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual("@bob:test", pushers[0]["user_name"]) + self.assertEqual("@bob:test", pushers[0].user_name) @override_config( { @@ -801,7 +886,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -827,7 +912,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Change password body = json.dumps({"password": "hahaha"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -844,7 +929,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Modify user body = json.dumps({"displayname": "foobar"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -856,7 +941,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("foobar", channel.json_body["displayname"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) @@ -874,7 +959,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} ) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -887,7 +972,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) @@ -904,7 +989,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Deactivate user body = json.dumps({"deactivated": True}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -917,7 +1002,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # the user is deactivated, the threepid will be deleted # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) @@ -925,13 +1010,61 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) + def test_change_name_deactivate_user_user_directory(self): + """ + Test change profile information of a deactivated user and + check that it does not appear in user directory + """ + + # is in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile["display_name"] == "User") + + # Deactivate user + body = json.dumps({"deactivated": True}) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + # is not in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile is None) + + # Set new displayname user + body = json.dumps({"displayname": "Foobar"}) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual("Foobar", channel.json_body["displayname"]) + + # is not in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile is None) + def test_reactivate_user(self): """ Test reactivating another user. """ # Deactivate the user. - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -944,7 +1077,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", True) # Attempt to reactivate the user (without a password). - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -953,7 +1086,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Reactivate the user. - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -964,7 +1097,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) @@ -981,7 +1114,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set a user as an admin body = json.dumps({"admin": True}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -993,7 +1126,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(True, channel.json_body["admin"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) @@ -1011,7 +1144,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -1023,9 +1156,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("bob", channel.json_body["displayname"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1035,7 +1166,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Change password (and use a str for deactivate instead of a bool) body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, @@ -1045,9 +1176,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Check user is not deactivated - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1089,7 +1218,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ Try to list rooms of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1100,9 +1229,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1112,9 +1239,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1125,9 +1250,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1138,9 +1261,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): if user has no memberships """ # Get rooms - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1157,9 +1278,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) @@ -1188,7 +1307,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ Try to list pushers of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1199,9 +1318,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1211,9 +1328,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1224,9 +1339,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1237,9 +1350,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ # Get pushers - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1266,9 +1377,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ) # Get pushers - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1307,7 +1416,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ Try to list media of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1318,9 +1427,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1330,9 +1437,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/media" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1343,9 +1448,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1359,7 +1462,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) @@ -1378,7 +1481,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) @@ -1397,7 +1500,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) @@ -1412,7 +1515,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Testing that a negative limit parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) @@ -1424,7 +1527,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Testing that a negative from parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) @@ -1442,7 +1545,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) @@ -1453,7 +1556,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) @@ -1464,7 +1567,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) @@ -1476,7 +1579,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Check # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) @@ -1491,9 +1594,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): if user has no media created """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1508,9 +1609,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) @@ -1576,7 +1675,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): ) def _get_token(self) -> str: - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1585,7 +1684,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): def test_no_auth(self): """Try to login as a user without authentication. """ - request, channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1593,7 +1692,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): def test_not_admin(self): """Try to login as a user as a non-admin user. """ - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, b"{}", access_token=self.other_user_tok ) @@ -1621,7 +1720,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self._get_token() # Check that we don't see a new device in our devices list - request, channel = self.make_request( + channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1636,25 +1735,19 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): puppet_token = self._get_token() # Test that we can successfully make a request - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Logout with the puppet token - request, channel = self.make_request( - "POST", "logout", b"{}", access_token=puppet_token - ) + channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # The puppet token should no longer work - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) # .. but the real user's tokens should still work - request, channel = self.make_request( + channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1667,25 +1760,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): puppet_token = self._get_token() # Test that we can successfully make a request - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Logout all with the real user token - request, channel = self.make_request( + channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # The puppet token should still work - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # .. but the real user's tokens shouldn't - request, channel = self.make_request( + channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) @@ -1698,25 +1787,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): puppet_token = self._get_token() # Test that we can successfully make a request - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Logout all with the admin user token - request, channel = self.make_request( + channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # The puppet token should no longer work - request, channel = self.make_request( - "GET", "devices", b"{}", access_token=puppet_token - ) + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) # .. but the real user's tokens should still work - request, channel = self.make_request( + channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1798,11 +1883,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ Try to get information of an user without authentication. """ - request, channel = self.make_request("GET", self.url1, b"{}") + channel = self.make_request("GET", self.url1, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("GET", self.url2, b"{}") + channel = self.make_request("GET", self.url2, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1813,15 +1898,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.register_user("user2", "pass") other_user2_token = self.login("user2", "pass") - request, channel = self.make_request( - "GET", self.url1, access_token=other_user2_token, - ) + channel = self.make_request("GET", self.url1, access_token=other_user2_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( - "GET", self.url2, access_token=other_user2_token, - ) + channel = self.make_request("GET", self.url2, access_token=other_user2_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1832,15 +1913,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" - request, channel = self.make_request( - "GET", url1, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url1, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) - request, channel = self.make_request( - "GET", url2, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", url2, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) @@ -1848,16 +1925,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ The lookup should succeed for an admin. """ - request, channel = self.make_request( - "GET", self.url1, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - request, channel = self.make_request( - "GET", self.url2, access_token=self.admin_user_tok, - ) + channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -1868,16 +1941,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url1, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url1, access_token=other_user_token,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - request, channel = self.make_request( - "GET", self.url2, access_token=other_user_token, - ) + channel = self.make_request("GET", self.url2, access_token=other_user_token,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index e2e6a5e16d..c74693e9b2 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -61,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): def test_render_public_consent(self): """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False ) self.assertEqual(channel.code, 200) @@ -82,7 +82,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") + "&u=user" ) - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "GET", @@ -97,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): self.assertEqual(consented, "False") # POST to the consent page, saying we've agreed - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "POST", @@ -109,7 +109,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # Fetch the consent page, to get the consent version -- it should have # changed - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "GET", diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index a1ccc4ee9a..56937dcd2e 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -93,7 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): def get_event(self, room_id, event_id, expected_code=200): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) - request, channel = self.make_request("GET", url) + channel = self.make_request("GET", url) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 259c6a1985..c0a9fc6925 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -43,9 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") - request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok - ) + channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) room_id = channel.json_body["room_id"] @@ -56,7 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") - request, channel = self.make_request( + channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index c1f516cc93..f0707646bb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -69,16 +69,12 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - request, channel = self.make_request( - "POST", path, content={}, access_token=access_token - ) + channel = self.make_request("POST", path, content={}, access_token=access_token) self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body def _sync_room_timeline(self, access_token, room_id): - request, channel = self.make_request( - "GET", "sync", access_token=self.mod_access_token - ) + channel = self.make_request("GET", "sync", access_token=self.mod_access_token) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index f56b5d9231..31dc832fd5 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -325,7 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def get_event(self, room_id, event_id, expected_code=200): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) - request, channel = self.make_request("GET", url, access_token=self.token) + channel = self.make_request("GET", url, access_token=self.token) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 94dcfb9f7c..e689c3fbea 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -89,7 +89,7 @@ class RoomTestCase(_ShadowBannedBase): ) # Inviting the user completes successfully. - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/invite" % (room_id,), {"id_server": "test", "medium": "email", "address": "test@test.test"}, @@ -103,7 +103,7 @@ class RoomTestCase(_ShadowBannedBase): def test_create_room(self): """Invitations during a room creation should be discarded, but the room still gets created.""" # The room creation is successful. - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/createRoom", {"visibility": "public", "invite": [self.other_user_id]}, @@ -158,7 +158,7 @@ class RoomTestCase(_ShadowBannedBase): self.banned_user_id, tok=self.banned_access_token ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), {"new_version": "6"}, @@ -183,7 +183,7 @@ class RoomTestCase(_ShadowBannedBase): self.banned_user_id, tok=self.banned_access_token ) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (room_id, self.banned_user_id), {"typing": True, "timeout": 30000}, @@ -198,7 +198,7 @@ class RoomTestCase(_ShadowBannedBase): # The other user can join and send typing events. self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (room_id, self.other_user_id), {"typing": True, "timeout": 30000}, @@ -244,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase): ) # The update should succeed. - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,), {"displayname": new_display_name}, @@ -254,7 +254,7 @@ class ProfileTestCase(_ShadowBannedBase): self.assertEqual(channel.json_body, {}) # The user's display name should be updated. - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/%s/displayname" % (self.banned_user_id,) ) self.assertEqual(channel.code, 200, channel.result) @@ -282,7 +282,7 @@ class ProfileTestCase(_ShadowBannedBase): ) # The update should succeed. - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room_id, self.banned_user_id), diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 0e96697f9b..227fffab58 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -86,7 +86,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): callback = Mock(spec=[], side_effect=check) current_rules_module().check_event_allowed = callback - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id, {}, @@ -104,7 +104,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.assertEqual(ev.type, k[0]) self.assertEqual(ev.state_key, k[1]) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id, {}, @@ -123,7 +123,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): current_rules_module().check_event_allowed = check # now send the event - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, {"x": "x"}, @@ -142,7 +142,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): current_rules_module().check_event_allowed = check # now send the event - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, {"x": "x"}, @@ -152,7 +152,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): event_id = channel.json_body["event_id"] # ... and check that it got modified - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 7a2c653df8..edd1d184f8 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -91,7 +91,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # that we can make sure that the check is done on the whole alias. data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, 400, channel.result) @@ -104,7 +104,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # as cautious as possible here. data = {"room_alias_name": random_string(5)} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, 200, channel.result) @@ -118,7 +118,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): data = {"aliases": [self.random_alias(alias_length)]} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) @@ -128,7 +128,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 12a93f5687..0a5ca317ea 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -63,13 +63,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # implementation is now part of the r0 implementation, the newer # behaviour is used instead to be consistent with the r0 spec. # see issue #2602 - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) self.assertEquals(channel.code, 401, msg=channel.result) # valid token, expect content - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) self.assertEquals(channel.code, 200, msg=channel.result) @@ -87,7 +87,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ) # valid token, expect content - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) self.assertEquals(channel.code, 200, msg=channel.result) @@ -149,7 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/events/" + event_id, access_token=self.token, ) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 176ddf7ec9..18932d7518 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -63,7 +63,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -82,7 +82,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -108,7 +108,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -127,7 +127,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -153,7 +153,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -172,7 +172,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"403", channel.result) @@ -181,7 +181,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token - request, channel = self.make_request(b"GET", TEST_URL) + channel = self.make_request(b"GET", TEST_URL) self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") @@ -191,25 +191,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -223,9 +219,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -233,16 +227,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # ... but if we delete that device, it will be a proper logout self._delete_device(access_token_2, "kermit", "monkey", device_id) - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], False) def _delete_device(self, access_token, user_id, password, device_id): """Perform the UI-Auth to delete a device""" - request, channel = self.make_request( + channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) self.assertEquals(channel.code, 401, channel.result) @@ -262,7 +254,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "session": channel.json_body["session"], } - request, channel = self.make_request( + channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token, @@ -278,26 +270,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token = self.login("kermit", "monkey") # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) # Now try to hard logout this session - request, channel = self.make_request( - b"POST", "/logout", access_token=access_token - ) + channel = self.make_request(b"POST", "/logout", access_token=access_token) self.assertEquals(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) @@ -308,26 +294,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token = self.login("kermit", "monkey") # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions - request, channel = self.make_request( - b"POST", "/logout/all", access_token=access_token - ) + channel = self.make_request(b"POST", "/logout/all", access_token=access_token) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -402,7 +382,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_ticket_url = urllib.parse.urlunparse(url_parts) # Get Synapse to call the fake CAS and serve the template. - request, channel = self.make_request("GET", cas_ticket_url) + channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. self.assertEqual(channel.code, 200) @@ -446,7 +426,7 @@ class CASTestCase(unittest.HomeserverTestCase): ) # Get Synapse to call the fake CAS and serve the template. - request, channel = self.make_request("GET", cas_ticket_url) + channel = self.make_request("GET", cas_ticket_url) self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") @@ -472,7 +452,7 @@ class CASTestCase(unittest.HomeserverTestCase): ) # Get Synapse to call the fake CAS and serve the template. - request, channel = self.make_request("GET", cas_ticket_url) + channel = self.make_request("GET", cas_ticket_url) # Because the user is deactivated they are served an error template. self.assertEqual(channel.code, 403) @@ -495,14 +475,18 @@ class JWTTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs - def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} ) - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) return channel def test_login_jwt_valid_registered(self): @@ -634,7 +618,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_no_token(self): params = json.dumps({"type": "org.matrix.login.jwt"}) - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -700,14 +684,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = "RS256" return self.hs - def jwt_encode(self, token, secret=jwt_privatekey): - return jwt.encode(token, secret, "RS256").decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} ) - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) return channel def test_login_jwt_valid(self): @@ -735,7 +723,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): ] def register_as_user(self, username): - request, channel = self.make_request( + self.make_request( b"POST", "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), {"username": username}, @@ -784,7 +772,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) @@ -799,7 +787,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": self.service.sender}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) @@ -814,7 +802,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) @@ -829,7 +817,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.another_service.token ) @@ -845,6 +833,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 11cd8efe21..94a5154834 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -53,7 +53,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.hs.config.use_presence = True body = {"presence": "here", "status_msg": "beep boop"} - request, channel = self.make_request( + channel = self.make_request( "PUT", "/presence/%s/status" % (self.user_id,), body ) @@ -68,7 +68,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.hs.config.use_presence = False body = {"presence": "here", "status_msg": "beep boop"} - request, channel = self.make_request( + channel = self.make_request( "PUT", "/presence/%s/status" % (self.user_id,), body ) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 2a3b483eaf..e59fa70baa 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -189,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.owner_tok = self.login("owner", "pass") def test_set_displayname(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test"}), @@ -202,7 +202,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" - request, channel = self.make_request( + channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test" * 100}), @@ -214,9 +214,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "owner") def get_displayname(self): - request, channel = self.make_request( - "GET", "/profile/%s/displayname" % (self.owner,) - ) + channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] @@ -278,7 +276,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): ) def request_profile(self, expected_code, url_suffix="", access_token=None): - request, channel = self.make_request( + channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) self.assertEqual(channel.code, expected_code, channel.result) @@ -320,19 +318,19 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): """Tests that a user can lookup their own profile without having to be in a room if 'require_auth_for_profile_requests' is set to true in the server's config. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester, access_token=self.requester_tok ) self.assertEqual(channel.code, 200, channel.result) - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester + "/displayname", access_token=self.requester_tok, ) self.assertEqual(channel.code, 200, channel.result) - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester + "/avatar_url", access_token=self.requester_tok, diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py index 7add5523c8..2bc512d75e 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -45,13 +45,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) @@ -74,13 +74,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # disable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": False}, @@ -89,26 +89,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # check rule disabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) self.assertEqual(channel.code, 200) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) @@ -130,13 +130,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # disable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": False}, @@ -145,14 +145,14 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # check rule disabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) # re-enable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": True}, @@ -161,7 +161,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # check rule enabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) @@ -182,32 +182,32 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 200) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) self.assertEqual(channel.code, 200) # check 404 for deleted rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) @@ -221,7 +221,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token ) self.assertEqual(channel.code, 404) @@ -235,7 +235,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": True}, @@ -252,7 +252,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/.m.muahahah/enabled", {"enabled": True}, @@ -276,13 +276,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # GET actions for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) self.assertEqual(channel.code, 200) @@ -305,13 +305,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # change the rule actions - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/actions", {"actions": ["dont_notify"]}, @@ -320,7 +320,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): self.assertEqual(channel.code, 200) # GET actions for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) self.assertEqual(channel.code, 200) @@ -341,26 +341,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) self.assertEqual(channel.code, 200) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) self.assertEqual(channel.code, 200) # check 404 for deleted rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) self.assertEqual(channel.code, 404) @@ -374,7 +374,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token ) self.assertEqual(channel.code, 404) @@ -388,7 +388,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/actions", {"actions": ["dont_notify"]}, @@ -405,7 +405,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/.m.muahahah/actions", {"actions": ["dont_notify"]}, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 55d872f0ee..6105eac47c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -84,13 +84,13 @@ class RoomPermissionsTestCase(RoomBase): self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) ).encode("ascii") - request, channel = self.make_request( + channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) self.assertEquals(200, channel.code, channel.result) # set topic for public room - request, channel = self.make_request( + channel = self.make_request( "PUT", ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', @@ -112,7 +112,7 @@ class RoomPermissionsTestCase(RoomBase): ) # send message in uncreated room, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, @@ -120,24 +120,24 @@ class RoomPermissionsTestCase(RoomBase): self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): @@ -145,30 +145,30 @@ class RoomPermissionsTestCase(RoomBase): topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 @@ -176,29 +176,29 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", topic_path, topic_content) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", topic_path) + channel = self.make_request("GET", topic_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, @@ -208,7 +208,7 @@ class RoomPermissionsTestCase(RoomBase): def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) - request, channel = self.make_request("GET", path) + channel = self.make_request("GET", path) self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): @@ -380,16 +380,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): - request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") - request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): @@ -398,17 +398,17 @@ class RoomsMemberListTestCase(RoomBase): room_path = "/rooms/%s/members" % room_id self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - request, channel = self.make_request("GET", room_path) + channel = self.make_request("GET", room_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined - request, channel = self.make_request("GET", room_path) + channel = self.make_request("GET", room_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - request, channel = self.make_request("GET", room_path) + channel = self.make_request("GET", room_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -419,30 +419,26 @@ class RoomsCreateTestCase(RoomBase): def test_post_room_no_keys(self): # POST with no config keys, expect new room id - request, channel = self.make_request("POST", "/createRoom", "{}") + channel = self.make_request("POST", "/createRoom", "{}") self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - request, channel = self.make_request( - "POST", "/createRoom", b'{"visibility":"private"}' - ) + channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - request, channel = self.make_request( - "POST", "/createRoom", b'{"custom":"stuff"}' - ) + channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - request, channel = self.make_request( + channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) self.assertEquals(200, channel.code) @@ -450,16 +446,16 @@ class RoomsCreateTestCase(RoomBase): def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - request, channel = self.make_request("POST", "/createRoom", b'{"visibili') + channel = self.make_request("POST", "/createRoom", b'{"visibili') self.assertEquals(400, channel.code) - request, channel = self.make_request("POST", "/createRoom", b'["hello"]') + channel = self.make_request("POST", "/createRoom", b'["hello"]') self.assertEquals(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # Note the trailing space in the MXID here! - request, channel = self.make_request( + channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) self.assertEquals(400, channel.code) @@ -477,54 +473,54 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json - request, channel = self.make_request("PUT", self.path, "{}") + channel = self.make_request("PUT", self.path, "{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + channel = self.make_request("PUT", self.path, '{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '{"nao') + channel = self.make_request("PUT", self.path, '{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, "text only") + channel = self.make_request("PUT", self.path, "text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, "") + channel = self.make_request("PUT", self.path, "") self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - request, channel = self.make_request("PUT", self.path, content) + channel = self.make_request("PUT", self.path, content) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there - request, channel = self.make_request("GET", self.path) + channel = self.make_request("GET", self.path) self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - request, channel = self.make_request("PUT", self.path, content) + channel = self.make_request("PUT", self.path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = self.make_request("GET", self.path) + channel = self.make_request("GET", self.path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - request, channel = self.make_request("PUT", self.path, content) + channel = self.make_request("PUT", self.path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = self.make_request("GET", self.path) + channel = self.make_request("GET", self.path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -540,24 +536,22 @@ class RoomMemberStateTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = self.make_request("PUT", path, "{}") + channel = self.make_request("PUT", path, "{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '{"_name":"bo"}') + channel = self.make_request("PUT", path, '{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '{"nao') + channel = self.make_request("PUT", path, '{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( - "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' - ) + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, "text only") + channel = self.make_request("PUT", path, "text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, "") + channel = self.make_request("PUT", path, "") self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types @@ -566,7 +560,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = self.make_request("PUT", path, content.encode("ascii")) + channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): @@ -577,10 +571,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = self.make_request("PUT", path, content.encode("ascii")) + channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} @@ -595,10 +589,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -614,10 +608,10 @@ class RoomMemberStateTestCase(RoomBase): Membership.INVITE, "Join us!", ) - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -668,7 +662,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id - request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + channel = self.make_request("PUT", path, {"displayname": "John Doe"}) self.assertEquals(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. @@ -678,7 +672,7 @@ class RoomJoinRatelimitTestCase(RoomBase): self.user_id, ) - request, channel = self.make_request("GET", path) + channel = self.make_request("GET", path) self.assertEquals(channel.code, 200) self.assertIn("displayname", channel.json_body) @@ -702,7 +696,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Make sure we send more requests than the rate-limiting config would allow # if all of these requests ended up joining the user to a room. for i in range(4): - request, channel = self.make_request("POST", path % room_id, {}) + channel = self.make_request("POST", path % room_id, {}) self.assertEquals(channel.code, 200) @unittest.override_config( @@ -731,42 +725,40 @@ class RoomMessagesTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = self.make_request("PUT", path, b"{}") + channel = self.make_request("PUT", path, b"{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') + channel = self.make_request("PUT", path, b'{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'{"nao') + channel = self.make_request("PUT", path, b'{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( - "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' - ) + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b"text only") + channel = self.make_request("PUT", path, b"text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b"") + channel = self.make_request("PUT", path, b"") self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' - request, channel = self.make_request("PUT", path, content) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -780,9 +772,7 @@ class RoomInitialSyncTestCase(RoomBase): self.room_id = self.helper.create_room_as(self.user_id) def test_initial_sync(self): - request, channel = self.make_request( - "GET", "/rooms/%s/initialSync" % self.room_id - ) + channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) @@ -823,7 +813,7 @@ class RoomMessageListTestCase(RoomBase): def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) self.assertEquals(200, channel.code) @@ -834,7 +824,7 @@ class RoomMessageListTestCase(RoomBase): def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) self.assertEquals(200, channel.code) @@ -867,7 +857,7 @@ class RoomMessageListTestCase(RoomBase): self.helper.send(self.room_id, "message 3") # Check that we get the first and second message when querying /messages. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -895,7 +885,7 @@ class RoomMessageListTestCase(RoomBase): # Check that we only get the second message through /message now that the first # has been purged. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -912,7 +902,7 @@ class RoomMessageListTestCase(RoomBase): # Check that we get no event, but also no error, when querying /messages with # the token that was pointing at the first event, because we don't have it # anymore. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -971,7 +961,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.helper.send(self.room, body="Hi!", tok=self.other_access_token) self.helper.send(self.room, body="There!", tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % (self.access_token,), { @@ -1000,7 +990,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.helper.send(self.room, body="Hi!", tok=self.other_access_token) self.helper.send(self.room, body="There!", tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % (self.access_token,), { @@ -1048,14 +1038,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): return self.hs def test_restricted_no_auth(self): - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401, channel.result) def test_restricted_auth(self): self.register_user("user", "pass") tok = self.login("user", "pass") - request, channel = self.make_request("GET", self.url, access_token=tok) + channel = self.make_request("GET", self.url, access_token=tok) self.assertEqual(channel.code, 200, channel.result) @@ -1083,7 +1073,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.displayname = "test user" data = {"displayname": self.displayname} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, @@ -1096,7 +1086,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): def test_per_room_profile_forbidden(self): data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id), @@ -1106,7 +1096,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, @@ -1139,7 +1129,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_join_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/join".format(self.room_id), content={"reason": reason}, @@ -1153,7 +1143,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), content={"reason": reason}, @@ -1167,7 +1157,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1181,7 +1171,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1193,7 +1183,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_unban_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1205,7 +1195,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_invite_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, @@ -1224,7 +1214,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), content={"reason": reason}, @@ -1235,7 +1225,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self._check_for_reason(reason) def _check_for_reason(self, reason): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( self.room_id, self.second_user_id @@ -1284,7 +1274,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by a label on a /context request.""" event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), @@ -1314,7 +1304,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by the absence of a label on a /context request.""" event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), @@ -1349,7 +1339,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): """ event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), @@ -1377,7 +1367,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), @@ -1394,7 +1384,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), @@ -1417,7 +1407,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % ( @@ -1448,7 +1438,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) @@ -1483,7 +1473,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) @@ -1530,7 +1520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) @@ -1651,7 +1641,7 @@ class ContextTestCase(unittest.HomeserverTestCase): # Check that we can still see the messages before the erasure request. - request, channel = self.make_request( + channel = self.make_request( "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), @@ -1715,7 +1705,7 @@ class ContextTestCase(unittest.HomeserverTestCase): # Check that a user that joined the room after the erasure request can't see # the messages anymore. - request, channel = self.make_request( + channel = self.make_request( "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), @@ -1805,7 +1795,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" % (self.room_id,), @@ -1826,7 +1816,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) self.assertEqual(channel.code, expected_code, channel.result) @@ -1856,14 +1846,14 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) self.assertEqual(channel.code, expected_code, channel.result) def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "GET", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), access_token=self.room_owner_tok, @@ -1875,7 +1865,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), json.dumps(content), diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index ae0207366b..38c51525a3 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -94,7 +94,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user="@jim:red") def test_set_typing(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', @@ -117,7 +117,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ) def test_set_not_typing(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', @@ -125,7 +125,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) def test_typing_timeout(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', @@ -138,7 +138,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 2) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 5a18af8d34..dbc27893b5 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -81,7 +81,7 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "POST", @@ -157,7 +157,7 @@ class RestHelper: data = {"membership": membership} data.update(extra_data) - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "PUT", @@ -192,7 +192,7 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "PUT", @@ -248,9 +248,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - _, channel = make_request( - self.hs.get_reactor(), self.site, method, path, content - ) + channel = make_request(self.hs.get_reactor(), self.site, method, path, content) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -333,7 +331,7 @@ class RestHelper: """ image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), FakeSite(resource), "POST", @@ -366,7 +364,7 @@ class RestHelper: client_redirect_url = "https://x" # first hit the redirect url (which will issue a cookie and state) - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "GET", @@ -411,7 +409,7 @@ class RestHelper: with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): # now hit the callback URI with the right params and a made-up code - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "GET", @@ -434,7 +432,7 @@ class RestHelper: # finally, submit the matrix login token to the login API, which gives us our # matrix access token and device id. - _, channel = make_request( + channel = make_request( self.hs.get_reactor(), self.site, "POST", @@ -445,7 +443,7 @@ class RestHelper: return channel.json_body -# an 'oidc_config' suitable for login_with_oidc. +# an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_CONFIG = { "enabled": True, "discover": False, diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 2ac1ecb7d3..cb87b80e33 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -19,7 +19,6 @@ import os import re from email.parser import Parser from typing import Optional -from urllib.parse import urlencode import pkg_resources @@ -241,7 +240,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) def _request_token(self, email, client_secret): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, @@ -255,7 +254,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") # Load the password reset confirmation page - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.submit_token_resource), "GET", @@ -268,20 +267,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button - # Send arguments as url-encoded form data, matching the template's behaviour - form_args = [] - for key, value_list in request.args.items(): - for value in value_list: - arg = (key, value) - form_args.append(arg) - # Confirm the password reset - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.submit_token_resource), "POST", path, - content=urlencode(form_args).encode("utf8"), + content=b"", shorthand=False, content_is_form=True, ) @@ -310,7 +302,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): def _reset_password( self, new_password, session_id, client_secret, expected_code=200 ): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/password", { @@ -352,8 +344,8 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) # Check that this access token has been invalidated. - request, channel = self.make_request("GET", "account/whoami") - self.assertEqual(request.code, 401) + channel = self.make_request("GET", "account/whoami") + self.assertEqual(channel.code, 401) def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" @@ -407,10 +399,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase): "erase": False, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): @@ -530,7 +522,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self._validate_token(link) - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -548,7 +540,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -569,7 +561,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/delete", {"medium": "email", "address": self.email}, @@ -578,7 +570,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -601,7 +593,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/delete", {"medium": "email", "address": self.email}, @@ -612,7 +604,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -629,7 +621,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEquals(len(self.email_attempts), 1) # Attempt to add email without clicking the link - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -647,7 +639,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -662,7 +654,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): session_id = "weasle" # Attempt to add email without even requesting an email - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -680,7 +672,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) @@ -784,9 +776,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if next_link: body["next_link"] = next_link - request, channel = self.make_request( - "POST", b"account/3pid/email/requestToken", body, - ) + channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) self.assertEquals(expect_code, channel.code, channel.result) return channel.json_body.get("sid") @@ -794,7 +784,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", ): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, @@ -807,7 +797,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Remove the host path = link.replace("https://example.com", "") - request, channel = self.make_request("GET", path, shorthand=False) + channel = self.make_request("GET", path, shorthand=False) self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -841,7 +831,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self._validate_token(link) - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -859,7 +849,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index ac67a9de29..ac66a4e0b7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.http.site import SynapseRequest from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register from synapse.rest.oidc import OIDCResource @@ -67,11 +66,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase): def register(self, expected_response: int, body: JsonDict) -> FakeChannel: """Make a register request.""" - request, channel = self.make_request( - "POST", "register", body - ) # type: SynapseRequest, FakeChannel + channel = self.make_request("POST", "register", body) - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel def recaptcha( @@ -81,18 +78,18 @@ class FallbackAuthTests(unittest.HomeserverTestCase): if post_session is None: post_session = session - request, channel = self.make_request( + channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request.code, 200) + ) + self.assertEqual(channel.code, 200) - request, channel = self.make_request( + channel = self.make_request( "POST", "auth/m.login.recaptcha/fallback/web?session=" + post_session + "&g-recaptcha-response=a", ) - self.assertEqual(request.code, expected_post_response) + self.assertEqual(channel.code, expected_post_response) # The recaptcha handler is called with the response given attempts = self.recaptcha_checker.recaptcha_attempts @@ -180,13 +177,8 @@ class UIAuthTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) - self.user_tok = self.login("test", self.user_pass) - - def get_device_ids(self, access_token: str) -> List[str]: - # Get the list of devices so one can be deleted. - _, channel = self.make_request("GET", "devices", access_token=access_token,) - self.assertEqual(channel.code, 200) - return [d["device_id"] for d in channel.json_body["devices"]] + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) def delete_device( self, @@ -196,12 +188,12 @@ class UIAuthTests(unittest.HomeserverTestCase): body: Union[bytes, JsonDict] = b"", ) -> FakeChannel: """Delete an individual device.""" - request, channel = self.make_request( + channel = self.make_request( "DELETE", "devices/" + device, body, access_token=access_token, - ) # type: SynapseRequest, FakeChannel + ) # Ensure the response is sane. - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel @@ -209,12 +201,12 @@ class UIAuthTests(unittest.HomeserverTestCase): """Delete 1 or more devices.""" # Note that this uses the delete_devices endpoint so that we can modify # the payload half-way through some tests. - request, channel = self.make_request( + channel = self.make_request( "POST", "delete_devices", body, access_token=self.user_tok, - ) # type: SynapseRequest, FakeChannel + ) # Ensure the response is sane. - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel @@ -222,11 +214,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids(self.user_tok)[0] - # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -236,7 +226,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow. self.delete_device( self.user_tok, - device_id, + self.device_id, 200, { "auth": { @@ -255,14 +245,13 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - device_id = self.get_device_ids(self.user_tok)[0] - channel = self.delete_device(self.user_tok, device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, - device_id, + self.device_id, 200, { "auth": { @@ -285,14 +274,11 @@ class UIAuthTests(unittest.HomeserverTestCase): session ID should be rejected. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids(self.user_tok) - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + channel = self.delete_devices(401, {"devices": [self.device_id]}) # Grab the session session = channel.json_body["session"] @@ -304,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_devices( 200, { - "devices": [device_ids[1]], + "devices": ["dev2"], "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": self.user}, @@ -319,14 +305,11 @@ class UIAuthTests(unittest.HomeserverTestCase): The initial requested URI cannot be modified during the user interactive authentication session. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids(self.user_tok) - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -335,9 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. self.delete_device( self.user_tok, - device_ids[1], + "dev2", 403, { "auth": { @@ -349,6 +334,52 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. + channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + def test_does_not_offer_password_for_sso_user(self): login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] @@ -364,8 +395,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def test_does_not_offer_sso_for_password_user(self): # now call the device deletion API: we should get the option to auth with SSO # and not password. - device_ids = self.get_device_ids(self.user_tok) - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -376,8 +406,7 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - device_ids = self.get_device_ids(self.user_tok) - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] # we have no particular expectations of ordering here diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index 767e126875..e808339fb3 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -36,7 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): return hs def test_check_auth_required(self): - request, channel = self.make_request("GET", self.url) + channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) @@ -44,7 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.register_user("user", "pass") access_token = self.login("user", "pass") - request, channel = self.make_request("GET", self.url, access_token=access_token) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -62,7 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): user = self.register_user(localpart, password) access_token = self.login(user, password) - request, channel = self.make_request("GET", self.url, access_token=access_token) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -70,7 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) - request, channel = self.make_request("GET", self.url, access_token=access_token) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 231d5aefea..f761c44936 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -36,7 +36,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastore() def test_add_filter(self): - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, @@ -49,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, @@ -61,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, @@ -79,7 +79,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.reactor.advance(1) filter_id = filter_id.result - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) @@ -87,7 +87,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) @@ -97,7 +97,7 @@ class FilterTestCase(unittest.HomeserverTestCase): # Currently invalid params do not have an appropriate errcode # in errors.py def test_get_filter_invalid_id(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) @@ -105,7 +105,7 @@ class FilterTestCase(unittest.HomeserverTestCase): # No ID also returns an invalid_id error def test_get_filter_no_id(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index ee86b94917..fba34def30 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -70,9 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_get_policy(self): """Tests if the /password_policy endpoint returns the configured policy.""" - request, channel = self.make_request( - "GET", "/_matrix/client/r0/password_policy" - ) + channel = self.make_request("GET", "/_matrix/client/r0/password_policy") self.assertEqual(channel.code, 200, channel.result) self.assertEqual( @@ -89,7 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_too_short(self): request_data = json.dumps({"username": "kermit", "password": "shorty"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -98,7 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_digit(self): request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -107,7 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_symbol(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -116,7 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_uppercase(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -125,7 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_lowercase(self): request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -134,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_compliant(self): request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) - request, channel = self.make_request("POST", self.register_url, request_data) + channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. @@ -160,7 +158,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): }, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/account/password", request_data, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index bcb21d0ced..27db4f551e 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -61,7 +61,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.get_datastore().services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) - request, channel = self.make_request( + channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) @@ -72,7 +72,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists request_data = json.dumps({"username": "kermit"}) - request, channel = self.make_request( + channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) @@ -80,14 +80,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid username") @@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "auth": {"type": LoginType.DUMMY}, } request_data = json.dumps(params) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) det_data = { "user_id": user_id, @@ -117,7 +117,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") @@ -127,7 +127,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.macaroon_secret_key = "test" self.hs.config.allow_guest_access = True - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) @@ -136,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_disabled_guest_registration(self): self.hs.config.allow_guest_access = False - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") @@ -145,7 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_ratelimiting_guest(self): for i in range(0, 6): url = self.url + b"?kind=guest" - request, channel = self.make_request(b"POST", url, b"{}") + channel = self.make_request(b"POST", url, b"{}") if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -155,7 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"200", channel.result) @@ -169,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "auth": {"type": LoginType.DUMMY}, } request_data = json.dumps(params) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -179,12 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"200", channel.result) def test_advertised_flows(self): - request, channel = self.make_request(b"POST", self.url, b"{}") + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -207,7 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } ) def test_advertised_flows_captcha_and_terms_and_3pids(self): - request, channel = self.make_request(b"POST", self.url, b"{}") + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -239,7 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } ) def test_advertised_flows_no_msisdn_email_required(self): - request, channel = self.make_request(b"POST", self.url, b"{}") + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -279,7 +279,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, @@ -318,13 +318,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( @@ -346,14 +346,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/account_validity/validity" params = {"user_id": user_id} request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) def test_manual_expire(self): @@ -370,14 +368,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): "enable_renewal_emails": False, } request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result @@ -397,20 +393,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): "enable_renewal_emails": False, } request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # Try to log the user out - request, channel = self.make_request(b"POST", "/logout", access_token=tok) + channel = self.make_request(b"POST", "/logout", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions - request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) + channel = self.make_request(b"POST", "/logout/all", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -484,7 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # retrieve the token from the DB. renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - request, channel = self.make_request(b"GET", url) + channel = self.make_request(b"GET", url) self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. @@ -504,14 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # our access token should be denied from now, otherwise they should # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - request, channel = self.make_request(b"GET", "/sync", access_token=tok) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) def test_renewal_invalid_token(self): # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" - request, channel = self.make_request(b"GET", url) + channel = self.make_request(b"GET", url) self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. @@ -532,7 +526,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] (user_id, tok) = self.create_user() - request, channel = self.make_request( + channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, @@ -556,10 +550,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "erase": False, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) @@ -607,7 +601,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] # Test that we're still able to manually trigger a mail to be sent. - request, channel = self.make_request( + channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 6cd4eb6624..bd574077e7 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -60,7 +60,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, @@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" % (self.room, self.parent_id), @@ -152,7 +152,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" % (self.room, self.parent_id, from_token), @@ -210,7 +210,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" % (self.room, self.parent_id, from_token), @@ -279,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s" "/aggregations/%s/%s/m.reaction/%s?limit=1%s" @@ -325,7 +325,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), @@ -357,7 +357,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), access_token=self.user_token, @@ -365,7 +365,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), @@ -382,7 +382,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that aggregations must be annotations. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" % (self.room, self.parent_id, RelationTypes.REPLACE), @@ -414,7 +414,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, @@ -450,7 +450,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, @@ -507,7 +507,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, @@ -549,7 +549,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Check the relation is returned - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" % (self.room, original_event_id), @@ -561,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(len(channel.json_body["chunk"]), 1) # Redact the original event - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/redact/%s/%s" % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), @@ -571,7 +571,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" % (self.room, original_event_id), @@ -598,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Redact the original - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/redact/%s/%s" % ( @@ -612,7 +612,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Check that aggregations returns zero - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" % (self.room, original_event_id), @@ -656,7 +656,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): original_id = parent_id if parent_id else self.parent_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index 562a9c1ba4..116ace1812 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import shared_rooms from tests import unittest +from tests.server import FakeChannel class UserSharedRoomsTest(unittest.HomeserverTestCase): @@ -40,14 +41,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.store = hs.get_datastore() self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user): - request, channel = self.make_request( + def _get_shared_rooms(self, token, other_user) -> FakeChannel: + return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" % other_user, access_token=token, ) - return request, channel def test_shared_room_list_public(self): """ @@ -63,7 +63,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) @@ -82,7 +82,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) @@ -104,7 +104,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room_public, user=u2, tok=u2_token) self.helper.join(room_private, user=u1, tok=u1_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 2) self.assertTrue(room_public in channel.json_body["joined"]) @@ -125,13 +125,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Assert user directory is not empty - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) - request, channel = self._get_shared_rooms(u2_token, u1) + channel = self._get_shared_rooms(u2_token, u1) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 31ac0fccb8..512e36c236 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -35,7 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ] def test_sync_argless(self): - request, channel = self.make_request("GET", "/sync") + channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) self.assertTrue( @@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase): """ self.hs.config.use_presence = False - request, channel = self.make_request("GET", "/sync") + channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) self.assertTrue( @@ -194,7 +194,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): tok=tok, ) - request, channel = self.make_request( + channel = self.make_request( "GET", "/sync?filter=%s" % sync_filter, access_token=tok ) self.assertEqual(channel.code, 200, channel.result) @@ -245,21 +245,19 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) # Start typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) self.assertEquals(200, channel.code) - request, channel = self.make_request( - "GET", "/sync?access_token=%s" % (access_token,) - ) + channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', @@ -267,7 +265,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) # Start typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', @@ -275,9 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) # Should return immediately - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -289,9 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # invalidate the stream token. self.helper.send(room, body="There!", tok=other_access_token) - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -299,9 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # ahead, and therefore it's saying the typing (that we've actually # already seen) is new, since it's got a token above our new, now-reset # stream token. - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -383,7 +375,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/read_markers" % self.room_id, body, @@ -450,7 +442,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def _check_unread_count(self, expected_count: True): """Syncs and compares the unread count with the expected value.""" - request, channel = self.make_request( + channel = self.make_request( "GET", self.url % self.next_batch, access_token=self.tok, ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6f0677d335..ae2b32b131 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -228,7 +228,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _req(self, content_disposition): - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.download_resource), "GET", @@ -324,7 +324,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method - request, channel = make_request( + channel = make_request( self.reactor, FakeSite(self.thumbnail_resource), "GET", diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 529b6bcded..83d728b4a4 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -113,7 +113,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): def test_cache_returns_correct_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -138,7 +138,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Check the cache returns the correct response - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False ) @@ -154,7 +154,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertNotIn("http://matrix.org", self.preview_url._cache) # Check the database cache returns the correct response - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False ) @@ -175,7 +175,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -210,7 +210,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -245,7 +245,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://matrix.org", shorthand=False, @@ -278,7 +278,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False, @@ -308,7 +308,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -329,7 +329,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -346,7 +346,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Blacklisted IP addresses, accessed directly, are not spidered. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://192.168.1.1", shorthand=False ) @@ -365,7 +365,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Blacklisted IP ranges, accessed directly, are not spidered. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://1.1.1.2", shorthand=False ) @@ -385,7 +385,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False, @@ -422,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): (IPv4Address, "10.1.2.3"), ] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) self.assertEqual(channel.code, 502) @@ -442,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") ] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -463,7 +463,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False ) @@ -480,7 +480,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ OPTIONS returns the OPTIONS. """ - request, channel = self.make_request( + channel = self.make_request( "OPTIONS", "preview_url?url=http://example.com", shorthand=False ) self.assertEqual(channel.code, 200) @@ -493,7 +493,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://example.com", shorthand=False, @@ -567,7 +567,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, @@ -632,7 +632,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): } end_content = json.dumps(result).encode("utf-8") - request, channel = self.make_request( + channel = self.make_request( "GET", "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py index 02a46e5fda..32acd93dc1 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py @@ -25,7 +25,7 @@ class HealthCheckTests(unittest.HomeserverTestCase): return HealthResource() def test_health(self): - request, channel = self.make_request("GET", "/health", shorthand=False) + channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 6a930f4148..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -28,11 +28,11 @@ class WellKnownTests(unittest.HomeserverTestCase): self.hs.config.public_baseurl = "https://tesths" self.hs.config.default_identity_server = "https://testis" - request, channel = self.make_request( + channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, { @@ -44,8 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase): def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None - request, channel = self.make_request( + channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(request.code, 404) + self.assertEqual(channel.code, 404) diff --git a/tests/server.py b/tests/server.py index 4faf32e335..7d1ad362c4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -174,11 +174,11 @@ def make_request( custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, -): +) -> FakeChannel: """ Make a web request using the given method, path and content, and render it - Returns the Request and the Channel underneath. + Returns the fake Channel object which records the response to the request. Args: site: The twisted Site to use to render the request @@ -202,7 +202,7 @@ def make_request( is finished. Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + channel """ if not isinstance(method, bytes): method = method.encode("ascii") @@ -265,7 +265,7 @@ def make_request( if await_result: channel.await_result() - return req, channel + return channel @implementer(IReactorPluggableNameResolver) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index e0a9cd93ac..4dd5a36178 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -70,7 +70,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): the notice URL + an authentication code. """ # Initial sync, to get the user consent room invite - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) self.assertEqual(channel.code, 200) @@ -79,7 +79,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): room_id = list(channel.json_body["rooms"]["invite"].keys())[0] # Join the room - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/" + room_id + "/join", access_token=self.access_token, @@ -87,7 +87,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) # Sync again, to get the message in the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) self.assertEqual(channel.code, 200) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 9c8027a5b2..fea54464af 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -305,7 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.register_user("user", "password") tok = self.login("user", "password") - request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) invites = channel.json_body["rooms"]["invite"] self.assertEqual(len(invites), 0, invites) @@ -318,7 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Sync again to retrieve the events in the room, so we can check whether this # room has a notice in it. - request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) # Scan the events in the room to search for a message from the server notices # user. @@ -353,9 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): tok = self.login(localpart, "password") # Sync with the user's token to mark the user as active. - request, channel = self.make_request( - "GET", "/sync?timeout=0", access_token=tok, - ) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,) # Also retrieves the list of invites for this user. We don't care about that # one except if we're processing the last user, which should have received an diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 09f4f32a02..77c72834f2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -88,7 +88,7 @@ class FakeEvent: event_dict = { "auth_events": [(a, {}) for a in auth_events], "prev_events": [(p, {}) for p in prev_events], - "event_id": self.node_id, + "event_id": self.event_id, "sender": self.sender, "type": self.type, "content": self.content, @@ -381,6 +381,61 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) + def test_mainline_sort(self): + """Tests that the mainline ordering works correctly. + """ + + events = [ + FakeEvent( + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": {ALICE: 100, BOB: 50}, + "events": {EventTypes.PowerLevels: 100}, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + ] + + edges = [ + ["END", "T3", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T4", "PB", "PA1"], + ] + + # We expect T3 to be picked as the other topics are pointing at older + # power levels. Note that without mainline ordering we'd pick T4 due to + # it being sent *after* T3. + expected_state_ids = ["T3", "PA2"] + + self.do_check(events, edges, expected_state_ids) + def do_check(self, events, edges, expected_state_ids): """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..71210ce606 --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions +from synapse.federation.federation_base import event_from_pdu_json +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.unittest import HomeserverTestCase + + +class ExtremPruneTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastore() + + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=self.token + ) + + body = self.helper.send(self.room_id, body="Test", tok=self.token) + local_message_event_id = body["event_id"] + + # Fudge a remote event and persist it. This will be the extremity before + # the gap. + self.remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Message, + "state_key": "@user:other", + "content": {}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + self.persist_event(self.remote_event_1) + + # Check that the current extremities is the remote event. + self.assert_extremities([self.remote_event_1.event_id]) + + def persist_event(self, event, state=None): + """Persist the event, with optional state + """ + context = self.get_success( + self.state.compute_event_context(event, old_state=state) + ) + self.get_success(self.persistence.persist_event(event, context)) + + def assert_extremities(self, expected_extremities): + """Assert the current extremities for the room + """ + extremities = self.get_success( + self.store.get_prev_events_for_room(self.room_id) + ) + self.assertCountEqual(extremities, expected_extremities) + + def test_prune_gap(self): + """Test that we drop extremities after a gap when we see an event from + the same domain. + """ + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 50, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_do_not_prune_gap_if_state_different(self): + """Test that we don't prune extremities after a gap if the resolved + state is different. + """ + + # Fudge a second event which points to an event we don't have. + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Message, + "state_key": "@user:other", + "content": {}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 10, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + # Now we persist it with state with a dropped history visibility + # setting. The state resolution across the old and new event will then + # include it, and so the resolved state won't match the new state. + state_before_gap = dict( + self.get_success(self.state.get_current_state(self.room_id)) + ) + state_before_gap.pop(("m.room.history_visibility", "")) + + context = self.get_success( + self.state.compute_event_context( + remote_event_2, old_state=state_before_gap.values() + ) + ) + + self.get_success(self.persistence.persist_event(remote_event_2, context)) + + # Check that we haven't dropped the old extremity. + self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) + + def test_prune_gap_if_old(self): + """Test that we drop extremities after a gap when the previous extremity + is "old" + """ + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_do_not_prune_gap_if_other_server(self): + """Test that we do not drop extremities after a gap when we see an event + from a different domain. + """ + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) + + def test_prune_gap_if_dummy_remote(self): + """Test that we drop extremities after a gap when the previous extremity + is a local dummy event and only points to remote events. + """ + + body = self.helper.send_event( + self.room_id, type=EventTypes.Dummy, content={}, tok=self.token + ) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_prune_gap_if_dummy_local(self): + """Test that we don't drop extremities after a gap when the previous + extremity is a local dummy event and points to local events. + """ + + body = self.helper.send(self.room_id, body="Test", tok=self.token) + + body = self.helper.send_event( + self.room_id, type=EventTypes.Dummy, content={}, tok=self.token + ) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id, local_message_event_id]) + + def test_do_not_prune_gap_if_not_dummy(self): + """Test that we do not drop extremities after a gap when the previous extremity + is not a dummy event. + """ + + body = self.helper.send(self.room_id, body="test", tok=self.token) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([local_message_event_id, remote_event_2.event_id]) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 7e7f1286d9..e9e3bca3bf 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase): self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) + + users, total = yield defer.ensureDeferred( + self.store.get_users_paginate(0, 10, name="BC", guests=False) + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 738e912468..a6f63f4aaf 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" +# The localpart isn't 'Bela' on purpose so we can test looking up display names. +BELA = "@somenickname:a" class UserDirectoryStoreTestCase(unittest.TestCase): @@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase): self.store.update_profile_in_user_dir(BOBBY, "bobby", None) ) yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BELA, "Bela", None) + ) + yield defer.ensureDeferred( self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) ) @@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase): ) finally: self.hs.config.user_directory_search_all_users = False + + @defer.inlineCallbacks + def test_search_user_dir_stop_words(self): + """Tests that a user can look up another user by searching for the start if its + display name even if that name happens to be a common English word that would + usually be ignored in full text searches. + """ + self.hs.config.user_directory_search_all_users = True + try: + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) + finally: + self.hs.config.user_directory_search_all_users = False diff --git a/tests/test_mau.py b/tests/test_mau.py index c5ec6396a7..51660b51d5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -19,6 +19,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha import register, sync from tests import unittest @@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_as_ignores_mau(self): + """Test that application services can still create users when the MAU + limit has been reached. This only works when application service + user ip tracking is disabled. + """ + + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + + # We've created and activated two users, we shouldn't be able to + # register new users + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit3") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + # Cheekily add an application service that we use to register a new user + # with. + as_token = "foobartoken" + self.store.services_cache.append( + ApplicationService( + token=as_token, + hostname=self.hs.hostname, + id="SomeASID", + sender="@as_sender:test", + namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, + ) + ) + + self.create_user("as_kermit4", token=as_token) + def test_allowed_after_a_month_mau(self): # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") @@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) - def create_user(self, localpart): + def create_user(self, localpart, token=None): request_data = json.dumps( { "username": localpart, @@ -201,7 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase): } ) - request, channel = self.make_request("POST", "/register", request_data) + channel = self.make_request( + "POST", "/register", request_data, access_token=token, + ) if channel.code != 200: raise HttpResponseException( @@ -213,7 +255,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request("GET", "/sync", access_token=token) + channel = self.make_request("GET", "/sync", access_token=token) if channel.code != 200: raise HttpResponseException( diff --git a/tests/test_server.py b/tests/test_server.py index 6b2d2f0401..815da18e65 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -64,11 +64,10 @@ class JsonResourceTests(unittest.TestCase): "test_servlet", ) - request, channel = make_request( + make_request( self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) - self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) def test_callback_direct_exception(self): @@ -85,7 +84,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -109,7 +108,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -127,7 +126,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -149,9 +148,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - _, channel = make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar" - ) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -173,7 +170,7 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. - _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") + channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -205,7 +202,7 @@ class OptionsResourceTests(unittest.TestCase): ) # render the request and return the channel - _, channel = make_request(self.reactor, site, method, path, shorthand=False) + channel = make_request(self.reactor, site, method, path, shorthand=False) return channel def test_unknown_options_request(self): @@ -278,7 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -296,7 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -317,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -336,7 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") + channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 71580b454d..a743cdc3a9 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request request_data = json.dumps({"username": "kermit", "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"401", channel.result) @@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase): self.registration_handler.check_username = Mock(return_value=True) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. @@ -113,7 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase): }, } ) - request, channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, request_data) # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 6873d45eb6..43898d8142 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -22,6 +22,8 @@ import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar +from mock import Mock + import attr from twisted.python.failure import Failure @@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]: return cleanup +def simple_async_mock(return_value=None, raises=None) -> Mock: + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + @attr.s class FakeResponse: """A fake twisted.web.IResponse object diff --git a/tests/unittest.py b/tests/unittest.py index 102b0a1f34..af7f752c5a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload +from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -372,38 +372,6 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest - # when the `request` arg isn't given, so we define an explicit override to - # cover that case. - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[SynapseRequest, FakeChannel]: - ... - - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - request: Type[T] = SynapseRequest, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[T, FakeChannel]: - ... - def make_request( self, method: Union[bytes, str], @@ -415,7 +383,10 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, - ) -> Tuple[T, FakeChannel]: + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the given content. @@ -437,8 +408,10 @@ class HomeserverTestCase(TestCase): true (the default), will pump the test reactor until the the renderer tells the channel the request is finished. + custom_headers: (name, value) pairs to add as request headers + Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + The FakeChannel object which stores the result of the request. """ return make_request( self.reactor, @@ -452,6 +425,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin, content_is_form, await_result, + custom_headers, ) def setup_test_homeserver(self, *args, **kwargs): @@ -568,7 +542,7 @@ class HomeserverTestCase(TestCase): self.hs.config.registration_shared_secret = "shared" # Create the user - request, channel = self.make_request("GET", "/_synapse/admin/v1/register") + channel = self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -593,7 +567,7 @@ class HomeserverTestCase(TestCase): "inhibit_login": True, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_synapse/admin/v1/register", body.encode("utf8") ) self.assertEqual(channel.code, 200, channel.json_body) @@ -611,7 +585,7 @@ class HomeserverTestCase(TestCase): if device_id: body["device_id"] = device_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 200, channel.result) @@ -679,7 +653,7 @@ class HomeserverTestCase(TestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 403, channel.result) |