diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0fd55f428a..ee5217b074 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -282,7 +282,11 @@ class AuthTestCase(unittest.TestCase):
)
)
self.store.add_access_token_to_user.assert_called_with(
- USER_ID, token, "DEVICE", None
+ user_id=USER_ID,
+ token=token,
+ device_id="DEVICE",
+ valid_until_ms=None,
+ puppets_user_id=None,
)
def get_user(tok):
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 4a301b84e1..40abe9d72d 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer
+from tests.server import make_request
from tests.unittest import HomeserverTestCase
@@ -55,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = site.resource.children[b"_matrix"].children[b"client"]
- request, channel = self.make_request("PUT", "presence/a/status")
- self.render(request)
+ _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@@ -77,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = site.resource.children[b"_matrix"].children[b"client"]
- request, channel = self.make_request("PUT", "presence/a/status")
- self.render(request)
+ _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c2b10d2c70..ea3be95cf1 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
+from tests.server import make_request
from tests.unittest import HomeserverTestCase
@@ -66,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+ site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/openid/userinfo"
+ _, channel = make_request(
+ self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
- self.render(request)
self.assertEqual(channel.code, 401)
@@ -115,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+ site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/openid/userinfo"
+ _, channel = make_request(
+ self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
- self.render(request)
self.assertEqual(channel.code, 401)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 1471cc1a28..0187f56e21 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -51,7 +51,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@@ -64,7 +63,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
- self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index da933ecd75..3009fbb6c4 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -51,7 +51,6 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
- self.render(request)
self.assertEquals(400, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@@ -99,7 +98,6 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(
@@ -132,7 +130,6 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 72e22d655f..f9e3c7a51f 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -40,7 +40,6 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/publicRooms"
)
- self.render(request)
self.assertEquals(403, channel.code)
@override_config({"allow_public_rooms_over_federation": True})
@@ -48,5 +47,4 @@ class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/federation/v1/publicRooms"
)
- self.render(request)
self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b5055e018c..e24ce81284 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.clock.now = 5000
+ self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.clock.now = 1000
+ self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
- self.hs.clock.now = 6000
+ self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 2ce6dc9528..ee6ef5e6fa 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -412,7 +412,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
def test_allowed(self):
@@ -423,7 +422,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -438,7 +436,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.room_list_handler = hs.get_room_list_handler()
@@ -452,7 +449,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list is enabled so we should get some results
request, channel = self.make_request("GET", b"publicRooms")
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -461,7 +457,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list disabled so we should get no results
request, channel = self.make_request("GET", b"publicRooms")
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
@@ -470,5 +465,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
new file mode 100644
index 0000000000..b7d340bcb8
--- /dev/null
+++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse.rest.admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import account
+
+from tests import unittest
+
+
+class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.address = "test@test"
+ self.is_server_name = "testis"
+ self.is_server_url = "https://testis"
+ self.rewritten_is_url = "https://int.testis"
+
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
+ config["rewrite_identity_server_urls"] = {
+ self.is_server_url: self.rewritten_is_url
+ }
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.side_effect = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.side_effect = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().blacklisting_http_client = (
+ mock_blacklisting_http_client
+ )
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+
+ def test_rewritten_id_server(self):
+ """
+ Tests that, when validating a 3PID association while rewriting the IS's server
+ name:
+ * the bind request is done against the rewritten hostname
+ * the original, non-rewritten, server name is stored in the database
+ """
+ handler = self.hs.get_identity_handler()
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
+ store = self.hs.get_datastore()
+
+ creds = {"sid": "123", "client_secret": "some_secret"}
+
+ # Make sure processing the mocked response goes through.
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
+ self.assertEqual(data.get("address"), self.address)
+
+ # Check that the request was done against the rewritten server name.
+ post_json_get_json.assert_called_once_with(
+ "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
+ {
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
+ )
+
+ # Check that the original server name is saved in the database instead of the
+ # rewritten one.
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
+ self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 8b57081cbe..af42775815 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -209,5 +209,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
- self.render(request)
self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0d51705849..a308c46da9 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
from urllib.parse import parse_qs, urlparse
@@ -24,12 +23,8 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
-from synapse.handlers.oidc_handler import (
- MappingException,
- OidcError,
- OidcHandler,
- OidcMappingProvider,
-)
+from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
+from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.unittest import HomeserverTestCase, override_config
@@ -94,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
+class TestMappingProviderFailures(TestMappingProvider):
+ async def map_user_attributes(self, userinfo, token, failures):
+ return {
+ "localpart": userinfo["username"] + (str(failures) if failures else ""),
+ "display_name": None,
+ }
+
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
@@ -124,22 +127,16 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
-
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = "Synapse Test"
-
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {}
- oidc_config["enabled"] = True
- oidc_config["client_id"] = CLIENT_ID
- oidc_config["client_secret"] = CLIENT_SECRET
- oidc_config["issuer"] = ISSUER
- oidc_config["scopes"] = SCOPES
- oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".TestMappingProvider",
+ oidc_config = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@@ -147,13 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
- hs = self.setup_test_homeserver(
- http_client=self.http_client,
- proxied_http_client=self.http_client,
- config=config,
- )
+ return config
- self.handler = OidcHandler(hs)
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+ self.handler = hs.get_oidc_handler()
+ sso_handler = hs.get_sso_handler()
+ # Mock the render error method.
+ self.render_error = Mock(return_value=None)
+ sso_handler.render_error = self.render_error
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@@ -161,12 +169,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
- args = self.handler._render_error.call_args[0]
+ args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
- self.handler._render_error.reset_mock()
+ self.render_error.reset_mock()
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -356,7 +364,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
- self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request))
@@ -387,7 +394,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar",
}
user_id = "@foo:domain.org"
- self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@@ -435,7 +441,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock(
@@ -469,7 +475,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@@ -485,7 +491,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
- self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie
@@ -699,19 +704,131 @@ class OidcHandlerTestCase(HomeserverTestCase):
),
MappingException,
)
- self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+ self.assertEqual(
+ str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ )
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
- user4 = UserID.from_string("@test_user_4:test")
+ user = UserID.from_string("@test_user:test")
+ self.get_success(
+ store.register_user(user_id=user.to_string(), password_hash=None)
+ )
+
+ # Map a user via SSO.
+ userinfo = {
+ "sub": "test",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Subsequent calls should map to the same mxid.
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Note that a second SSO user can be mapped to the same Matrix ID. (This
+ # requires a unique sub, but something that maps to the same matrix ID,
+ # in this case we'll just use the same username. A more realistic example
+ # would be subs which are email addresses, and mapping from the localpart
+ # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
+ userinfo = {
+ "sub": "test1",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Register some non-exact matching cases.
+ user2 = UserID.from_string("@TEST_user_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+ user2_caps = UserID.from_string("@test_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2_caps.to_string(), password_hash=None)
+ )
+
+ # Attempting to login without matching a name exactly is an error.
+ userinfo = {
+ "sub": "test2",
+ "username": "TEST_USER_2",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertTrue(
+ str(e.value).startswith(
+ "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
+ )
+ )
+
+ # Logging in when matching a name exactly should work.
+ user2 = UserID.from_string("@TEST_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@TEST_USER_2:test")
+
+ def test_map_userinfo_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ userinfo = {
+ "sub": "test2",
+ "username": "föö",
+ }
+ token = {}
+
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderFailures"
+ }
+ }
+ }
+ )
+ def test_map_userinfo_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
self.get_success(
- store.register_user(user_id=user4.to_string(), password_hash=None)
+ store.register_user(user_id="@test_user:test", password_hash=None)
)
userinfo = {
- "sub": "test4",
- "username": "test_user_4",
+ "sub": "test",
+ "username": "test_user",
}
token = {}
mxid = self.get_success(
@@ -719,4 +836,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, "user-agent", "10.10.10.10"
)
)
- self.assertEqual(mxid, "@test_user_4:test")
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential mxids for a particular OIDC username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
new file mode 100644
index 0000000000..ceaf0902d2
--- /dev/null
+++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,580 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the password_auth_provider interface"""
+
+from typing import Any, Type, Union
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel
+from tests.unittest import override_config
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
+# a mock instance which the dummy auth providers delegate to, so we can see what's going
+# on
+mock_password_provider = Mock()
+
+
+class PasswordOnlyAuthProvider:
+ """A password_provider which only implements `check_password`."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def check_password(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+class CustomAuthProvider:
+ """A password_provider which implements a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class PasswordCustomAuthProvider:
+ """A password_provider which implements password login via `check_auth`, as well
+ as a custom type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given password auth providers"""
+ return {
+ "password_providers": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
+class PasswordAuthProviderTests(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def setUp(self):
+ # we use a global mock device, so make sure we are starting with a clean slate
+ mock_password_provider.reset_mock()
+ super().setUp()
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_login(self):
+ # login flows should only have m.login.password
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # login with mxid should work too
+ channel = self._send_password_login("@u:bz", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:bz", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+ mock_password_provider.reset_mock()
+
+ # try a weird username / pass. Honestly it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with(
+ "@ USER🙂NAME :test", " pASS😢word "
+ )
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth(self):
+ """UI Auth should delegate correctly to the password provider"""
+
+ # create the user, otherwise access doesn't work
+ module_api = self.hs.get_module_api()
+ self.get_success(module_api.register_user("u"))
+
+ # log in twice, to get two devices
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ tok1 = self.login("u", "p")
+ self.login("u", "p", device_id="dev2")
+ mock_password_provider.reset_mock()
+
+ # have the auth provider deny the request to start with
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # make the initial request which returns a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Make another request providing the UI auth flow.
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # Finally, check the request goes through when we allow it
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_login(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@localuser:test", channel.json_body["user_id"])
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # have the auth provider deny the request
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Wrong password
+ channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "xxx"
+ )
+ mock_password_provider.reset_mock()
+
+ # Right password
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_login(self):
+ """localdb_enabled can block login with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_ui_auth(self):
+ """localdb_enabled can block ui auth with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # allow login via the auth provider
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "p")
+ self.login("localuser", "p", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # m.login.password UIA is permitted because the auth provider allows it,
+ # even though the localdb does not.
+ self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
+ session = channel.json_body["session"]
+ mock_password_provider.check_password.assert_not_called()
+
+ # now try deleting with the local password
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_auth_disabled(self):
+ """password auth doesn't work if it's disabled across the board"""
+ # login flows should be empty
+ flows = self._get_login_flows()
+ self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_login(self):
+ # login flows should have the custom flow and m.login.password, since we
+ # haven't disabled local password lookup.
+ # (password must come first, because reasons)
+ flows = self._get_login_flows()
+ self.assertEqual(
+ flows,
+ [{"type": "m.login.password"}, {"type": "test.login_type"}]
+ + ADDITIONAL_LOGIN_FLOWS,
+ )
+
+ # login with missing param should be rejected
+ channel = self._send_login("test.login_type", "u")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+ mock_password_provider.reset_mock()
+
+ # try a weird username. Again, it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@ MALFORMED! :bz"
+ )
+ channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_ui_auth(self):
+ # register the user and log in twice, to get two devices
+ self.register_user("localuser", "localpass")
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ # missing param
+ body = {
+ "auth": {
+ "type": "test.login_type",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
+ # use it...
+ self.assertIn("Missing parameters", channel.json_body["error"])
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # right params, but authing as the wrong user
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ body["auth"]["test_field"] = "foo"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+ mock_password_provider.reset_mock()
+
+ # and finally, succeed
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_callback(self):
+ callback = Mock(return_value=defer.succeed(None))
+
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", callback)
+ )
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+
+ # check the args to the callback
+ callback.assert_called_once()
+ call_args, call_kwargs = callback.call_args
+ # should be one positional arg
+ self.assertEqual(len(call_args), 1)
+ self.assertEqual(call_args[0]["user_id"], "@user:bz")
+ for p in ["user_id", "access_token", "device_id", "home_server"]:
+ self.assertIn(p, call_args[0])
+
+ @override_config(
+ {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
+ )
+ def test_custom_auth_password_disabled(self):
+ """Test login with a custom auth provider where password login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login(self):
+ """log in with a custom auth provider which implements password, but password
+ login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth(self):
+ """UI Auth with a custom auth provider which implements password, but password
+ login is disabled"""
+ # register the user and log in twice via the test login type to get two devices,
+ self.register_user("localuser", "localpass")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._send_login("test.login_type", "localuser", test_field="")
+ self.assertEqual(channel.code, 200, channel.result)
+ tok1 = channel.json_body["access_token"]
+
+ channel = self._send_login(
+ "test.login_type", "localuser", test_field="", device_id="dev2"
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected. In particular, "password" should *not*
+ # be present.
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ mock_password_provider.reset_mock()
+
+ # check that auth with password is rejected
+ body = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "password": "localpass",
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(
+ "Password login has been disabled.", channel.json_body["error"]
+ )
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # successful auth
+ body["auth"]["type"] = "test.login_type"
+ body["auth"]["test_field"] = "x"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "x"}
+ )
+
+ @override_config(
+ {
+ **providers_config(CustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback(self):
+ """Test login with a custom auth provider where the local db is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # password login shouldn't work and should be rejected with a 400
+ # ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_login_flows(self) -> JsonDict:
+ _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["flows"]
+
+ def _send_password_login(self, user: str, password: str) -> FakeChannel:
+ return self._send_login(type="m.login.password", user=user, password=password)
+
+ def _send_login(self, type, user, **params) -> FakeChannel:
+ params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+ _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ return channel
+
+ def _start_delete_device_session(self, access_token, device_id) -> str:
+ """Make an initial delete device request, and return the UI Auth session ID"""
+ channel = self._delete_device(access_token, device_id)
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ return channel.json_body["session"]
+
+ def _authed_delete_device(
+ self,
+ access_token: str,
+ device_id: str,
+ session: str,
+ user_id: str,
+ password: str,
+ ) -> FakeChannel:
+ """Make a delete device request, authenticating with the given uid/password"""
+ return self._delete_device(
+ access_token,
+ device_id,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": user_id},
+ "password": password,
+ "session": session,
+ },
+ },
+ )
+
+ def _delete_device(
+ self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ ) -> FakeChannel:
+ """Delete an individual device."""
+ _, channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token
+ )
+ return channel
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..1999bcbb38 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -65,7 +65,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
displayname = yield defer.ensureDeferred(
@@ -113,7 +113,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
)
self.assertEquals(
@@ -166,7 +166,7 @@ class ProfileTestCase(unittest.TestCase):
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield defer.ensureDeferred(
- self.store.set_profile_displayname("caroline", "Caroline")
+ self.store.set_profile_displayname("caroline", "Caroline", 1)
)
response = yield defer.ensureDeferred(
@@ -181,7 +181,7 @@ class ProfileTestCase(unittest.TestCase):
def test_get_my_avatar(self):
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -232,7 +232,7 @@ class ProfileTestCase(unittest.TestCase):
# Setting displayname for the first time is allowed
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.localpart, "http://my.server/me.png", 1
)
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bdf3d0a8a2..d979aadcd6 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,9 +18,15 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha.register import (
+ _map_email_to_displayname,
+ register_servlets,
+)
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
+from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@@ -31,6 +37,10 @@ from .. import unittest
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
+ servlets = [
+ register_servlets,
+ ]
+
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
@@ -517,6 +527,103 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ def test_email_to_displayname_mapping(self):
+ """Test that custom emails are mapped to new user displaynames correctly"""
+ self._check_mapping(
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
+ )
+
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
+
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
+
+ # Multibyte unicode characters
+ self._check_mapping(
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
+ )
+
+ def _check_mapping(self, i, expected):
+ result = _map_email_to_displayname(i)
+ self.assertEqual(result, expected)
+
+ @override_config(
+ {
+ "bind_new_user_emails_to_sydent": "https://is.example.com",
+ "registrations_require_3pid": ["email"],
+ "account_threepid_delegates": {},
+ "email": {
+ "smtp_host": "127.0.0.1",
+ "smtp_port": 20,
+ "require_transport_security": False,
+ "smtp_user": None,
+ "smtp_pass": None,
+ "notif_from": "test@example.com",
+ },
+ "public_baseurl": "http://localhost",
+ }
+ )
+ def test_user_email_bound_via_sydent_internal_api(self):
+ """Tests that emails are bound after registration if this option is set"""
+ # Register user with an email address
+ email = "alice@example.com"
+
+ # Mock Synapse's threepid validator
+ get_threepid_validation_session = Mock(
+ return_value=make_awaitable(
+ {"medium": "email", "address": email, "validated_at": 0}
+ )
+ )
+ self.store.get_threepid_validation_session = get_threepid_validation_session
+ delete_threepid_session = Mock(return_value=make_awaitable(None))
+ self.store.delete_threepid_session = delete_threepid_session
+
+ # Mock Synapse's http json post method to check for the internal bind call
+ post_json_get_json = Mock(return_value=make_awaitable(None))
+ self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
+
+ # Retrieve a UIA session ID
+ channel = self.uia_register(
+ 401, {"username": "alice", "password": "nobodywillguessthis"}
+ )
+ session_id = channel.json_body["session"]
+
+ # Register our email address using the fake validation session above
+ channel = self.uia_register(
+ 200,
+ {
+ "username": "alice",
+ "password": "nobodywillguessthis",
+ "auth": {
+ "session": session_id,
+ "type": "m.login.email.identity",
+ "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"},
+ },
+ },
+ )
+ self.assertEqual(channel.json_body["user_id"], "@alice:test")
+
+ # Check that a bind attempt was made to our fake identity server
+ post_json_get_json.assert_called_with(
+ "https://is.example.com/_matrix/identity/internal/bind",
+ {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"},
+ )
+
+ # Check that we stored a mapping of this bind
+ bound_threepids = self.get_success(
+ self.store.user_get_bound_threepids("@alice:test")
+ )
+ self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}])
+
+ def uia_register(self, expected_response: int, body: dict) -> FakeChannel:
+ """Make a register request."""
+ request, channel = self.make_request(
+ "POST", "register", body
+ ) # type: SynapseRequest, FakeChannel
+
+ self.assertEqual(request.code, expected_response)
+ return channel
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
new file mode 100644
index 0000000000..45dc17aba5
--- /dev/null
+++ b/tests/handlers/test_saml.py
@@ -0,0 +1,196 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import attr
+
+from synapse.api.errors import RedirectException
+from synapse.handlers.sso import MappingException
+
+from tests.unittest import HomeserverTestCase, override_config
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+
+
+@attr.s
+class FakeAuthnResponse:
+ ava = attr.ib(type=dict)
+
+
+class TestMappingProvider:
+ def __init__(self, config, module):
+ pass
+
+ @staticmethod
+ def parse_config(config):
+ return
+
+ @staticmethod
+ def get_saml_attributes(config):
+ return {"uid"}, {"displayName"}
+
+ def get_remote_user_id(self, saml_response, client_redirect_url):
+ return saml_response.ava["uid"]
+
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ localpart = saml_response.ava["username"] + (str(failures) if failures else "")
+ return {"mxid_localpart": localpart, "displayname": None}
+
+
+class TestRedirectMappingProvider(TestMappingProvider):
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ raise RedirectException(b"https://custom-saml-redirect/")
+
+
+class SamlHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ saml_config = {
+ "sp_config": {"metadata": {}},
+ # Disable grandfathering.
+ "grandfathered_mxid_source_attribute": None,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ saml_config.update(config.get("saml2_config", {}))
+ config["saml2_config"] = saml_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_saml_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_saml_response_to_user(self):
+ """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+ saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+ # The redirect_url doesn't matter with the default user mapping provider.
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+ def test_map_saml_response_to_existing_user(self):
+ """Existing users can log in with SAML account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # Map a user via SSO.
+ saml_response = FakeAuthnResponse(
+ {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
+ )
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Subsequent calls should map to the same mxid.
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ def test_map_saml_response_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
+ redirect_url = ""
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ def test_map_saml_response_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential mxids for a particular SAML username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ @override_config(
+ {
+ "saml2_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestRedirectMappingProvider"
+ },
+ }
+ }
+ )
+ def test_map_saml_response_redirect(self):
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ redirect_url = ""
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ RedirectException,
+ )
+ self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 312c0a0d41..0229f58315 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -21,8 +21,14 @@ from tests import unittest
# The expected number of state events in a fresh public room.
EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
# The expected number of state events in a fresh private room.
-EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
+#
+# Note: we increase this by 2 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms,
+# and an encryption state event as all private rooms are encrypted
+# by default
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
class StatsRoomTests(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e178d7765b..e62586142e 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
@@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
+ requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True
@@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
- self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
+ self.get_success(
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config)
+ )
# Test that global lock works
self.auth_blocking._hs_disabled = True
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
+ requester = create_requester(user_id2)
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 16ff2e22d2..abbdf2d524 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -228,7 +228,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
),
federation_auth_origin=b"farm",
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 87be94111f..4f6a912ac4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -537,7 +537,6 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) > 0)
@@ -546,6 +545,134 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ and that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_synapse/admin/v1/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.assertEqual(request.code, 200)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..c3f7a28dcc 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -101,7 +101,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=self.tls_factory,
+ tls_client_options_factory=FederationPolicyForHTTPS(config),
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 62d36c2906..05e9c449be 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
+from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
@@ -43,20 +44,18 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
- self.resource = AdditionalResource(self.hs, handler)
+ resource = AdditionalResource(self.hs, handler)
- request, channel = self.make_request("GET", "/")
- self.render(request)
+ request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
- self.resource = AdditionalResource(self.hs, handler)
+ resource = AdditionalResource(self.hs, handler)
- request, channel = self.make_request("GET", "/")
- self.render(request)
+ request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9b573ac24d..27206ca3db 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -94,12 +94,13 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertFalse(hasattr(event, "state_key"))
self.assertDictEqual(event.content, content)
+ expected_requester = create_requester(
+ user_id, authenticated_entity=self.hs.hostname
+ )
+
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
- create_requester(user_id),
- event_dict,
- ratelimit=False,
- ignore_shadow_ban=True,
+ expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
)
# Create and send a state event
@@ -128,7 +129,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the event was sent
self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
- create_requester(user_id),
+ expected_requester,
{
"type": "m.room.power_levels",
"content": content,
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8571924b29..cd46c485dd 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mock import Mock
from twisted.internet.defer import Deferred
@@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import receipts
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class HTTPPusherTests(HomeserverTestCase):
@@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
+ receipts.register_servlets,
]
user_id = True
hijack_auth = False
@@ -346,8 +347,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_mention(self):
"""
@@ -418,8 +419,8 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
def test_sends_high_priority_for_atroom(self):
"""
@@ -497,5 +498,163 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(self.push_attempts), 2)
self.assertEqual(self.push_attempts[1][1], "example.com")
- # check that this is low-priority
- self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+ # check that this is high-priority
+ self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
+
+ def test_push_unread_count_group_by_room(self):
+ """
+ The HTTP pusher will group unread count by number of unread rooms.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # This push should still only contain an unread count of 1 (for 1 unread room)
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
+ )
+
+ @override_config({"push": {"group_unread_count_by_room": False}})
+ def test_push_unread_count_message_count(self):
+ """
+ The HTTP pusher will send the total unread message count.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # We're counting every unread message, so there should now be 4 since the
+ # last read receipt
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
+ )
+
+ def _test_push_unread_count(self):
+ """
+ Tests that the correct unread count appears in sent push notifications
+
+ Note that:
+ * Sending messages will cause push notifications to go out to relevant users
+ * Sending a read receipt will cause a "badge update" notification to go out to
+ the user that sent the receipt
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("other_user", "pass")
+ other_access_token = self.login("other_user", "pass")
+
+ # Create a room (as other_user)
+ room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
+
+ # The user to get notified joins
+ self.helper.join(room=room_id, user=user_id, tok=access_token)
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send a message
+ response = self.helper.send(
+ room_id, body="Hello there!", tok=other_access_token
+ )
+ # To get an unread count, the user who is getting notified has to have a read
+ # position in the room. We'll set the read position to this event in a moment
+ first_message_event_id = response["event_id"]
+
+ # Advance time a bit (so the pusher will register something has happened) and
+ # make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+
+ # Check that the unread count for the room is 0
+ #
+ # The unread count is zero as the user has no read receipt in the room yet
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Now set the user's read receipt position to the first event
+ #
+ # This will actually trigger a new notification to be sent out so that
+ # even if the user does not receive another message, their unread
+ # count goes down
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+ {},
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Advance time and make the push succeed
+ self.push_attempts[1][0].callback({})
+ self.pump()
+
+ # Unread count is still zero as we've read the only message in the room
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(
+ self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Send another message
+ self.helper.send(
+ room_id, body="How's the weather today?", tok=other_access_token
+ )
+
+ # Advance time and make the push succeed
+ self.push_attempts[2][0].callback({})
+ self.pump()
+
+ # This push should contain an unread count of 1 as there's now been one
+ # message since our last read receipt
+ self.assertEqual(len(self.push_attempts), 3)
+ self.assertEqual(
+ self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
+ )
+
+ # Since we're grouping by room, sending more messages shouldn't increase the
+ # unread count, as they're all being sent in the same room
+ self.helper.send(room_id, body="Hello?", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[3][0].callback({})
+
+ self.helper.send(room_id, body="Hello??", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[4][0].callback({})
+
+ self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[5][0].callback({})
+
+ self.assertEqual(len(self.push_attempts), 6)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 5c633ac6df..295c5d58a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -36,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport, render
+from tests.server import FakeTransport
try:
import hiredis
@@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
- self.worker_hs.replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
@@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
- def create_test_json_resource(self):
- """Overrides `HomeserverTestCase.create_test_json_resource`.
+ def create_test_resource(self):
+ """Overrides `HomeserverTestCase.create_test_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
@@ -347,9 +347,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config["worker_replication_http_port"] = "8765"
return config
- def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
- render(request, self._hs_to_site[worker_hs].resource, self.reactor)
-
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..96801db473 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel
+from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
@@ -46,23 +46,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
- request_1, channel_1 = self.make_request(
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ request_2, channel_2 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
@@ -74,22 +79,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
- request_1, channel_1 = self.make_request(
+ site_1 = self._hs_to_site[worker_hs_1]
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ site_2 = self._hs_to_site[worker_hs_2]
+ request_2, channel_2 = make_request(
+ self.reactor,
+ site_2,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
- self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 77c261dbf7..48b574ccbe 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -28,7 +28,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.server import FakeChannel, FakeTransport
+from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
@@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
-
- request, channel = self.make_request(
+ resource = hs.get_media_repository_resource().children[b"download"]
+ _, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
"GET",
"/{}/{}".format(target, media_id),
shorthand=False,
access_token=self.access_token,
+ await_result=False,
)
- request.render(hs.get_media_repository_resource().children[b"download"])
self.pump()
clients = self.reactor.tcpClients
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 82cf033d4e..77fc3856d5 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
+ sync_hs_site = self._hs_to_site[sync_hs]
# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
@@ -178,8 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Do an initial sync so that we're up to date.
- request, channel = self.make_request("GET", "/sync", access_token=access_token)
- self.render_on_worker(sync_hs, request)
+ request, channel = make_request(
+ self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+ )
next_batch = channel.json_body["next_batch"]
# We now gut wrench into the events stream MultiWriterIdGenerator on
@@ -203,10 +206,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
- request, channel = self.make_request(
- "GET", "/sync?since={}".format(next_batch), access_token=access_token
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
# We should only see the new event and nothing else
self.assertIn(room_id1, channel.json_body["rooms"]["join"])
@@ -230,12 +236,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
self.assertIn(room_id2, channel.json_body["rooms"]["join"])
@@ -254,10 +261,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
- request, channel = self.make_request(
- "GET", "/sync?since={}".format(next_batch), access_token=access_token
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
"prev_batch"
@@ -269,50 +279,54 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
),
access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
),
access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertEqual(
channel.json_body["chunk"][0]["event_id"], first_event_in_room2
)
# Paginating forwards should give the same results
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
),
access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,
),
access_token=access_token,
)
- self.render_on_worker(sync_hs, request)
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertEqual(
channel.json_body["chunk"][0]["event_id"], first_event_in_room2
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0f1144fe1e..4f76f8f768 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -30,19 +30,19 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
+from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/server_version"
- def create_test_json_resource(self):
+ def create_test_resource(self):
resource = JsonResource(self.hs)
VersionServlet(self.hs).register(resource)
return resource
def test_version_string(self):
request, channel = self.make_request("GET", self.url, shorthand=False)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -75,7 +75,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
content={"localpart": "test"},
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
group_id = channel.json_body["group_id"]
@@ -88,14 +87,12 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
url = "/groups/%s/self/accept_invite" % (group_id,)
request, channel = self.make_request(
"PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check other user knows they're in the group
@@ -103,7 +100,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
# Now delete the group
- url = "/admin/delete_group/" + group_id
+ url = "/_synapse/admin/v1/delete_group/" + group_id
request, channel = self.make_request(
"POST",
url.encode("ascii"),
@@ -111,7 +108,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
content={"localpart": "test"},
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check group returns 404
@@ -131,7 +127,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -143,7 +138,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
"GET", "/joined_groups".encode("ascii"), access_token=access_token
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
@@ -222,11 +216,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
- request, channel = self.make_request(
- "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Should be quarantined
self.assertEqual(
@@ -247,7 +244,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
- self.render(request)
# Expect a forbidden error
self.assertEqual(
@@ -261,7 +257,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
- self.render(request)
# Expect a forbidden error
self.assertEqual(
@@ -287,14 +282,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
"GET",
server_name_and_media_id,
shorthand=False,
access_token=non_admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Should be successful
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -305,7 +300,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
urllib.parse.quote(media_id),
)
request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
- self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -358,7 +352,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
room_id
)
request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
- self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
@@ -405,7 +398,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
- self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -448,7 +440,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
- self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -462,14 +453,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Shouldn't be quarantined
self.assertEqual(
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index d89eb90cfe..cf3a007598 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -51,19 +51,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
Try to get a device of an user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
request, channel = self.make_request("PUT", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
request, channel = self.make_request("DELETE", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -75,7 +72,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -83,7 +79,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -91,7 +86,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -108,7 +102,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -116,7 +109,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -124,7 +116,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -141,7 +132,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -149,7 +139,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -157,7 +146,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -173,7 +161,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -181,14 +168,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
request, channel = self.make_request(
"DELETE", url, access_token=self.admin_user_tok,
)
- self.render(request)
# Delete unknown device returns status 200
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -218,7 +203,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
@@ -227,7 +211,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -247,7 +230,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -255,7 +237,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
@@ -272,7 +253,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -280,7 +260,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -292,7 +271,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -316,7 +294,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -347,7 +324,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
Try to list devices of an user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -361,7 +337,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -374,7 +349,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -388,7 +362,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -403,7 +376,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -422,7 +394,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
@@ -461,7 +432,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
Try to delete devices of an user without authentication.
"""
request, channel = self.make_request("POST", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -475,7 +445,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", self.url, access_token=other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -488,7 +457,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -502,7 +470,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -518,7 +485,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
# Delete unknown devices returns status 200
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -550,7 +516,6 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 303622217f..11b72c10f7 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -75,7 +75,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
Try to get an event report without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -88,7 +87,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -101,7 +99,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -117,7 +114,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -133,7 +129,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -149,7 +144,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -167,7 +161,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url + "?room_id=%s" % self.room_id1,
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
@@ -188,7 +181,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url + "?user_id=%s" % self.other_user,
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
@@ -209,7 +201,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 5)
@@ -230,7 +221,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?dir=b", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -247,7 +237,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?dir=f", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -268,7 +257,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -282,7 +270,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -295,7 +282,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -310,7 +296,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -322,7 +307,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -334,7 +318,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -347,7 +330,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -366,7 +348,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
json.dumps({"score": -100, "reason": "this makes me sad"}),
access_token=user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
@@ -419,7 +400,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
Try to get event report without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -432,7 +412,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -445,7 +424,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._check_fields(channel.json_body)
@@ -461,7 +439,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/event_reports/-123",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -476,7 +453,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/event_reports/abcdef",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -491,7 +467,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/event_reports/",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -510,7 +485,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/event_reports/123",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -528,7 +502,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
json.dumps({"score": -100, "reason": "this makes me sad"}),
access_token=user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def _check_fields(self, content):
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 721fa1ed51..dadf9db660 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
+from tests.server import FakeSite, make_request
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@@ -50,7 +51,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
request, channel = self.make_request("DELETE", url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -67,7 +67,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -81,7 +80,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -95,7 +93,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -124,14 +121,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name)
# Attempt to access media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=self.admin_user_tok,
)
- request.render(download_resource)
- self.pump(1.0)
# Should be successful
self.assertEqual(
@@ -152,7 +149,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
@@ -161,14 +157,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
)
# Attempt to access media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=self.admin_user_tok,
)
- request.render(download_resource)
- self.pump(1.0)
self.assertEqual(
404,
channel.code,
@@ -196,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
- self.clock = hs.clock
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -210,7 +205,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"""
request, channel = self.make_request("POST", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -225,7 +219,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", self.url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -239,7 +232,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -251,7 +243,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -266,7 +257,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -280,7 +270,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=1234&size_gt=-1234",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -294,7 +283,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=1234&keep_profiles=not_bool",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
@@ -325,7 +313,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
@@ -350,7 +337,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -363,7 +349,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
@@ -387,7 +372,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -399,7 +383,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
@@ -424,7 +407,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
@@ -433,7 +415,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -445,7 +426,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
@@ -471,7 +451,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
@@ -480,7 +459,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -492,7 +470,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
@@ -535,14 +512,14 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id)
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=self.admin_user_tok,
)
- request.render(download_resource)
- self.pump(1.0)
if expect_success:
self.assertEqual(
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 535d68f284..46933a0493 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -78,14 +78,13 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
)
# Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
+ url = "/_synapse/admin/v1/shutdown_room/" + room_id
request, channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -110,18 +109,16 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
+ url = "/_synapse/admin/v1/shutdown_room/" + room_id
request, channel = self.make_request(
"POST",
url.encode("ascii"),
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -136,7 +133,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -145,7 +141,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -192,7 +187,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", self.url, json.dumps({}), access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -206,7 +200,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url, json.dumps({}), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -220,7 +213,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url, json.dumps({}), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -239,7 +231,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertIn("new_room_id", channel.json_body)
@@ -259,7 +250,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -278,7 +268,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
@@ -295,7 +284,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
@@ -322,7 +310,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
@@ -356,7 +343,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
@@ -391,7 +377,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
@@ -439,7 +424,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
@@ -470,7 +454,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Test that room is not purged
@@ -488,7 +471,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"new_room_user_id": self.admin_user}),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
@@ -551,7 +533,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -560,7 +541,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
- self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -595,7 +575,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -647,7 +626,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
# Check request completed successfully
self.assertEqual(200, int(channel.code), msg=channel.json_body)
@@ -729,7 +707,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -770,7 +747,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def test_correct_room_attributes(self):
@@ -794,7 +770,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Set this new alias as the canonical alias for this room
@@ -822,7 +797,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Check that rooms were returned
@@ -867,7 +841,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"]
)
@@ -905,7 +878,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
@@ -1042,7 +1014,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
if expected_http_code != 200:
@@ -1104,7 +1075,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
@@ -1152,7 +1122,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
@@ -1164,7 +1133,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
@@ -1208,7 +1176,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1225,7 +1192,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -1242,7 +1208,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1259,7 +1224,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -1280,7 +1244,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("No known servers", channel.json_body["error"])
@@ -1298,7 +1261,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -1318,7 +1280,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
@@ -1328,7 +1289,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
@@ -1349,7 +1309,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1377,7 +1336,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1392,7 +1350,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["room_id"])
@@ -1401,7 +1358,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
@@ -1422,7 +1378,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["room_id"])
@@ -1432,7 +1387,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
)
- self.render(request)
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 816683a612..907b49f889 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -47,7 +47,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
Try to list users without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -59,7 +58,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, json.dumps({}), access_token=self.other_user_tok,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -72,7 +70,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -81,7 +78,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -90,7 +86,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -99,7 +94,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -108,7 +102,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -119,7 +112,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=10&until_ts=5",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -128,7 +120,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?search_term=", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -137,7 +128,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -151,7 +141,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 10)
@@ -168,7 +157,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -185,7 +173,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -206,7 +193,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
@@ -218,7 +204,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
@@ -230,7 +215,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
@@ -242,7 +226,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
@@ -258,7 +241,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -337,7 +319,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
@@ -346,7 +327,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
@@ -362,7 +342,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
@@ -370,7 +349,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
@@ -381,7 +359,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 20)
@@ -391,7 +368,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 11)
@@ -401,7 +377,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1)
@@ -410,7 +385,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], 0)
@@ -476,7 +450,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index d74efede06..54d46f4bd3 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -24,8 +24,8 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login, profile, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v1 import login, logout, profile, room
+from synapse.rest.client.v2_alpha import devices, sync
from tests import unittest
from tests.test_utils import make_awaitable
@@ -41,7 +41,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- self.url = "/_matrix/client/r0/admin/register"
+ self.url = "/_synapse/admin/v1/register"
self.registration_handler = Mock()
self.identity_handler = Mock()
@@ -71,7 +71,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs.config.registration_shared_secret = None
request, channel = self.make_request("POST", self.url, b"{}")
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -89,7 +88,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.hs.get_secrets = Mock(return_value=secrets)
request, channel = self.make_request("GET", self.url)
- self.render(request)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -99,7 +97,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
only last for SALT_TIMEOUT (60s).
"""
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
# 59 seconds
@@ -107,7 +104,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -116,7 +112,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.reactor.advance(2)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -126,7 +121,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
Only the provided nonce can be used, as it's checked in the MAC.
"""
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -143,7 +137,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -154,7 +147,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
user is registered.
"""
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -174,7 +166,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -184,7 +175,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
A valid unrecognised nonce.
"""
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -201,14 +191,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -222,7 +210,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
def nonce():
request, channel = self.make_request("GET", self.url)
- self.render(request)
return channel.json_body["nonce"]
#
@@ -232,7 +219,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -244,7 +230,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -252,7 +237,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -260,7 +244,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -268,7 +251,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -280,7 +262,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("password must be specified", channel.json_body["error"])
@@ -288,7 +269,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -296,7 +276,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -304,7 +283,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -323,7 +301,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -335,7 +312,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# set no displayname
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -346,19 +322,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
{"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob1:test", channel.json_body["user_id"])
request, channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -375,19 +348,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob2:test", channel.json_body["user_id"])
request, channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -404,18 +374,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob3:test", channel.json_body["user_id"])
request, channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.render(request)
self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
# set displayname
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -432,13 +399,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob4:test", channel.json_body["user_id"])
request, channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@@ -465,7 +430,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Register new user with admin API
request, channel = self.make_request("GET", self.url)
- self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -485,7 +449,6 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -511,7 +474,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
Try to list users without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
@@ -526,7 +488,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
b"{}",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(3, len(channel.json_body["users"]))
@@ -562,7 +523,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
@@ -570,7 +530,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, access_token=self.other_user_token, content=b"{}",
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("You are not a server admin", channel.json_body["error"])
@@ -585,7 +544,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v2/users/@unknown_person:test",
access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
@@ -613,7 +571,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -626,7 +583,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -659,7 +615,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -672,7 +627,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -700,7 +654,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/sync", access_token=self.admin_user_tok
)
- self.render(request)
if channel.code != 200:
raise HttpResponseException(
@@ -729,7 +682,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -769,7 +721,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
# Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -807,7 +758,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -852,7 +802,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -879,7 +828,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -897,7 +845,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -907,7 +854,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -929,7 +875,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -940,7 +885,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -961,7 +905,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -972,7 +915,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -990,7 +932,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._is_erased("@user:test", False)
d = self.store.mark_user_erased("@user:test")
@@ -1004,7 +945,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Reactivate the user.
@@ -1016,14 +956,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
encoding="utf_8"
),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -1044,7 +982,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -1054,7 +991,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
@@ -1076,7 +1012,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1086,7 +1021,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1102,7 +1036,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1110,7 +1043,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1153,7 +1085,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
Try to list rooms of an user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1167,7 +1098,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1180,7 +1110,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1194,7 +1123,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1208,7 +1136,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1228,7 +1155,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
@@ -1258,7 +1184,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
Try to list pushers of an user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1272,7 +1197,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1285,7 +1209,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1299,7 +1222,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1313,7 +1235,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1343,7 +1264,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
@@ -1383,7 +1303,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Try to list media of an user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
- self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1397,7 +1316,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)
- self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1410,7 +1328,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1424,7 +1341,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1441,7 +1357,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1461,7 +1376,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1481,7 +1395,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1497,7 +1410,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -1510,7 +1422,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
@@ -1529,7 +1440,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1541,7 +1451,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1553,7 +1462,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1566,7 +1474,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_media)
@@ -1582,7 +1489,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
@@ -1600,7 +1506,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
- self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
@@ -1638,3 +1543,336 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertIn("last_access_ts", m)
self.assertIn("quarantined_by", m)
self.assertIn("safe_from_quarantine", m)
+
+
+class UserTokenRestTestCase(unittest.HomeserverTestCase):
+ """Test for /_synapse/admin/v1/users/<user>/login
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ devices.register_servlets,
+ logout.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def _get_token(self) -> str:
+ request, channel = self.make_request(
+ "POST", self.url, b"{}", access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ return channel.json_body["access_token"]
+
+ def test_no_auth(self):
+ """Try to login as a user without authentication.
+ """
+ request, channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_not_admin(self):
+ """Try to login as a user as a non-admin user.
+ """
+ request, channel = self.make_request(
+ "POST", self.url, b"{}", access_token=self.other_user_tok
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_send_event(self):
+ """Test that sending event as a user works.
+ """
+ # Create a room.
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
+
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that sending works, and generates the event as the right user.
+ resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
+ event_id = resp["event_id"]
+ event = self.get_success(self.store.get_event(event_id))
+ self.assertEqual(event.sender, self.other_user)
+
+ def test_devices(self):
+ """Tests that logging in as a user doesn't create a new device for them.
+ """
+ # Login in as the user
+ self._get_token()
+
+ # Check that we don't see a new device in our devices list
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # We should only see the one device (from the login in `prepare`)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ def test_logout(self):
+ """Test that calling `/logout` with the token works.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout with the puppet token
+ request, channel = self.make_request(
+ "POST", "logout", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should no longer work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens should still work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_user_logout_all(self):
+ """Tests that the target user calling `/logout/all` does *not* expire
+ the token.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout all with the real user token
+ request, channel = self.make_request(
+ "POST", "logout/all", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should still work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens shouldn't
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_admin_logout_all(self):
+ """Tests that the admin user calling `/logout/all` does expire the
+ token.
+ """
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Test that we can successfully make a request
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Logout all with the admin user token
+ request, channel = self.make_request(
+ "POST", "logout/all", b"{}", access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # The puppet token should no longer work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=puppet_token
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+
+ # .. but the real user's tokens should still work
+ request, channel = self.make_request(
+ "GET", "devices", b"{}", access_token=self.other_user_tok
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://example.org/",
+ "user_consent": {
+ "version": "1.0",
+ "policy_name": "My Cool Privacy Policy",
+ "template_dir": "/",
+ "require_at_registration": True,
+ "block_events_error": "You should accept the policy",
+ },
+ "form_secret": "123secret",
+ }
+ )
+ def test_consent(self):
+ """Test that sending a message is not subject to the privacy policies.
+ """
+ # Have the admin user accept the terms.
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
+
+ # First, cheekily accept the terms and create a room
+ self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
+ self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)
+
+ # Now unaccept it and check that we can't send an event
+ self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
+ self.helper.send_event(
+ room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Login in as the user
+ puppet_token = self._get_token()
+
+ # Sending an event on their behalf should work fine
+ self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
+
+ @override_config(
+ {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
+ )
+ def test_mau_limit(self):
+ # Create a room as the admin user. This will bump the monthly active users to 1.
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Trying to join as the other user should fail due to reaching MAU limit.
+ self.helper.join(
+ room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
+ )
+
+ # Logging in as the other user and joining a room should work, even
+ # though the MAU limit would stop the user doing so.
+ puppet_token = self._get_token()
+ self.helper.join(room_id, user=self.other_user, tok=puppet_token)
+
+
+class WhoisRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user)
+ self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get information of an user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url1, b"{}")
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request("GET", self.url2, b"{}")
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ self.register_user("user2", "pass")
+ other_user2_token = self.login("user2", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url1, access_token=other_user2_token,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ request, channel = self.make_request(
+ "GET", self.url2, access_token=other_user2_token,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+ url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
+
+ request, channel = self.make_request(
+ "GET", url1, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only whois a local user", channel.json_body["error"])
+
+ request, channel = self.make_request(
+ "GET", url2, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only whois a local user", channel.json_body["error"])
+
+ def test_get_whois_admin(self):
+ """
+ The lookup should succeed for an admin.
+ """
+ request, channel = self.make_request(
+ "GET", self.url1, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ request, channel = self.make_request(
+ "GET", self.url2, access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ def test_get_whois_user(self):
+ """
+ The lookup should succeed for a normal user looking up their own information.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url1, access_token=other_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
+
+ request, channel = self.make_request(
+ "GET", self.url2, access_token=other_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+ self.assertIn("devices", channel.json_body)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 6803b372ac..e2e6a5e16d 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource
from tests import unittest
-from tests.server import render
+from tests.server import FakeSite, make_request
class ConsentResourceTestCase(unittest.HomeserverTestCase):
@@ -61,8 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
- request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
- render(request, resource, self.reactor)
+ request, channel = make_request(
+ self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+ )
self.assertEqual(channel.code, 200)
def test_accept_consent(self):
@@ -81,10 +82,14 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user"
)
- request, channel = self.make_request(
- "GET", consent_uri, access_token=access_token, shorthand=False
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ consent_uri,
+ access_token=access_token,
+ shorthand=False,
)
- render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
# Get the version from the body, and whether we've consented
@@ -92,21 +97,26 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
shorthand=False,
)
- render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
# Fetch the consent page, to get the consent version -- it should have
# changed
- request, channel = self.make_request(
- "GET", consent_uri, access_token=access_token, shorthand=False
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ consent_uri,
+ access_token=access_token,
+ shorthand=False,
)
- render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
# Get the version from the body, and check that it's the version we
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 5e9c07ebf3..a1ccc4ee9a 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -94,7 +94,6 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
request, channel = self.make_request("GET", url)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..aa094c3543 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,21 +39,95 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed({"mxid": "@f:test"})
+ mock_http_client.post_json_get_json.return_value = defer.succeed({})
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().http_client = mock_http_client
+
+ return self.hs
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
@@ -58,7 +139,40 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+
+ get_json = self.hs.get_identity_handler().http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ self.make_request("GET", url, access_token=self.tok)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ self.make_request("POST", url, request_data, access_token=self.tok)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
)
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index d2bcf256fa..c1f516cc93 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -72,7 +72,6 @@ class RedactionsTestCase(HomeserverTestCase):
request, channel = self.make_request(
"POST", path, content={}, access_token=access_token
)
- self.render(request)
self.assertEqual(int(channel.result["code"]), expect_code)
return channel.json_body
@@ -80,7 +79,6 @@ class RedactionsTestCase(HomeserverTestCase):
request, channel = self.make_request(
"GET", "sync", access_token=self.mod_access_token
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200")
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7d3773ff78..62b41ada42 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -243,6 +244,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
@@ -326,7 +328,6 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..f02a613a28
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1070 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import random
+import string
+from typing import Optional
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.rest import admin
+from synapse.rest.client.v1 import directory, login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULES_TYPE,
+ AccessRules,
+ RoomAccessRules,
+)
+from synapse.types import JsonDict, create_requester
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_identity_handler().blacklisting_http_client = mock_http_client
+
+ self.third_party_event_rules = self.hs.get_third_party_event_rules()
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right."""
+ room_id = self.create_room(rule=AccessRules.UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=AccessRules.DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400)
+
+ def test_create_room_default_power_level_rules(self):
+ """Tests that a room created with no power level overrides instead uses the dinum
+ defaults
+ """
+ room_id = self.create_room(direct=True, rule=AccessRules.DIRECT)
+ power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok)
+
+ # Inviting another user should require PL50, even in private rooms
+ self.assertEqual(power_levels["invite"], 50)
+ # Sending arbitrary state events should require PL100
+ self.assertEqual(power_levels["state_default"], 100)
+
+ def test_create_room_fails_on_incorrect_power_level_rules(self):
+ """Tests that a room created with power levels lower than that required are rejected"""
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ modified_power_levels["invite"] = 0
+ modified_power_levels["state_default"] = 50
+
+ self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ expected_code=400,
+ )
+
+ def test_create_room_with_missing_power_levels_use_default_values(self):
+ """
+ Tests that a room created with custom power levels, but without defining invite or state_default
+ succeeds, but the missing values are replaced with the defaults.
+ """
+
+ # Attempt to create a room without defining "invite" or "state_default"
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ del modified_power_levels["invite"]
+ del modified_power_levels["state_default"]
+ room_id = self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ )
+
+ # This should succeed, but the defaults should be put in place instead
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ self.assertEqual(room_power_levels["invite"], 50)
+ self.assertEqual(room_power_levels["state_default"], 100)
+
+ # And now the same test, but using power_levels_content_override instead
+ # of initial_state (which takes a slightly different codepath)
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ del modified_power_levels["invite"]
+ del modified_power_levels["state_default"]
+ room_id = self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ power_levels_content_override=modified_power_levels,
+ )
+
+ # This should succeed, but the defaults should be put in place instead
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ self.assertEqual(room_power_levels["invite"], 50)
+ self.assertEqual(room_power_levels["state_default"], 100)
+
+ def test_existing_room_can_change_power_levels(self):
+ """Tests that a room created with default power levels can have their power levels
+ dropped after room creation
+ """
+ # Creates a room with the default power levels
+ room_id = self.create_room(
+ direct=True, rule=AccessRules.DIRECT, expected_code=200,
+ )
+
+ # Attempt to drop invite and state_default power levels after the fact
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ room_power_levels["invite"] = 0
+ room_power_levels["state_default"] = 50
+ self.helper.send_state(
+ room_id, "m.room.power_levels", room_power_levels, self.tok
+ )
+
+ def test_public_room(self):
+ """Tests that it's only possible to have a room listed in the public room list
+ if the access rule is restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED
+ )
+
+ # List preset_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # List init_state_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.DIRECT, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule of direct
+ # should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=AccessRules.DIRECT,
+ expected_code=400,
+ )
+
+ # Changing join rule to public in an direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+
+ Also tests that the room can be published to, and removed from, the public room
+ list.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user which HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ # We are allowed to publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # We are allowed to remove the room from the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "private"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+
+ Tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to a room in which there's a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0]
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL and doesn't set
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefines the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ # We can change the rule from restricted to unrestricted.
+ # Note that this changes self.restricted_room to an unrestricted room
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't publish a room to the public room list and then change its rule to
+ # unrestricted
+
+ # Create a restricted room
+ test_room_id = self.create_room(rule=AccessRules.RESTRICTED)
+
+ # Publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Attempt to switch the room to "unrestricted"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Attempt to switch the room to "direct"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def test_check_event_allowed(self):
+ """Tests that RoomAccessRules.check_event_allowed behaves accordingly.
+
+ It tests that:
+ * forbidden users cannot join restricted rooms.
+ * forbidden users can only join unrestricted rooms if they have an invite.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+
+ # Test that forbidden users cannot join restricted rooms
+ requester = create_requester(self.user_id)
+ allowed_requester = create_requester("@user:allowed_domain")
+ forbidden_requester = create_requester("@user:forbidden_domain")
+
+ # Assert a join event from a forbidden user to a restricted room is rejected
+ self.get_failure(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ ),
+ SynapseError,
+ )
+
+ # A join event from an non-forbidden user to a restricted room is allowed
+ self.get_success(
+ event_creator.create_event(
+ allowed_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": allowed_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": allowed_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Test that forbidden users can only join unrestricted rooms if they have an invite
+
+ # A forbidden user without an invite should not be able to join an unrestricted room
+ self.get_failure(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ ),
+ SynapseError,
+ )
+
+ # However, if we then invite this user...
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=requester.user.to_string(),
+ targ=forbidden_requester.user.to_string(),
+ tok=self.tok,
+ )
+
+ # And create another join event, making sure that its context states it's coming
+ # in after the above invite was made...
+ # Then the forbidden user should be able to join!
+ self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ def test_freezing_a_room(self):
+ """Tests that the power levels in a room change to prevent new events from
+ non-admin users when the last admin of a room leaves.
+ """
+
+ def freeze_room_with_id_and_power_levels(
+ room_id: str, custom_power_levels_content: Optional[JsonDict] = None,
+ ):
+ # Invite a user to the room, they join with PL 0
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ if not custom_power_levels_content:
+ # Retrieve the room's current power levels event content
+ power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+ else:
+ power_levels = custom_power_levels_content
+
+ # Override the room's power levels with the given power levels content
+ self.helper.send_state(
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ body=custom_power_levels_content,
+ tok=self.tok,
+ )
+
+ # Ensure that the invitee leaving the room does not change the power levels
+ self.helper.leave(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Retrieve the new power levels of the room
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+
+ # Ensure they have not changed
+ self.assertDictEqual(power_levels, new_power_levels)
+
+ # Invite the user back again
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Now the admin leaves the room
+ self.helper.leave(
+ room=room_id, user=self.user_id, tok=self.tok,
+ )
+
+ # Check the power levels again
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok,
+ )
+
+ # Ensure that the new power levels prevent anyone but admins from sending
+ # certain events
+ self.assertEquals(new_power_levels["state_default"], 100)
+ self.assertEquals(new_power_levels["events_default"], 100)
+ self.assertEquals(new_power_levels["kick"], 100)
+ self.assertEquals(new_power_levels["invite"], 100)
+ self.assertEquals(new_power_levels["ban"], 100)
+ self.assertEquals(new_power_levels["redact"], 100)
+ self.assertDictEqual(new_power_levels["events"], {})
+ self.assertDictEqual(new_power_levels["users"], {self.user_id: 100})
+
+ # Ensure new users entering the room aren't going to immediately become admins
+ self.assertEquals(new_power_levels["users_default"], 0)
+
+ # Test that freezing a room with the default power level state event content works
+ room1 = self.create_room()
+ freeze_room_with_id_and_power_levels(room1)
+
+ # Test that freezing a room with a power level state event that is missing
+ # `state_default` and `event_default` keys behaves as expected
+ room2 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room2,
+ {
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `event_default` keys
+ },
+ )
+
+ # Test that freezing a room with a power level state event that is *additionally*
+ # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected
+ room3 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room3,
+ {
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `event_default` keys
+ # Explicitly remove `ban`, `invite`, `kick` and `redact` keys
+ },
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ power_levels_content_override=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ if power_levels_content_override:
+ content["power_levels_content_override"] = power_levels_content_override
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 6bb02b9630..94dcfb9f7c 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -95,7 +95,6 @@ class RoomTestCase(_ShadowBannedBase):
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# This should have raised an error earlier, but double check this wasn't called.
@@ -110,7 +109,6 @@ class RoomTestCase(_ShadowBannedBase):
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]
@@ -166,7 +164,6 @@ class RoomTestCase(_ShadowBannedBase):
{"new_version": "6"},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)
@@ -192,7 +189,6 @@ class RoomTestCase(_ShadowBannedBase):
{"typing": True, "timeout": 30000},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code)
# There should be no typing events.
@@ -208,7 +204,6 @@ class RoomTestCase(_ShadowBannedBase):
{"typing": True, "timeout": 30000},
access_token=self.other_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code)
# These appear in the room.
@@ -255,7 +250,6 @@ class ProfileTestCase(_ShadowBannedBase):
{"displayname": new_display_name},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(channel.json_body, {})
@@ -263,7 +257,6 @@ class ProfileTestCase(_ShadowBannedBase):
request, channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["displayname"], new_display_name)
@@ -296,7 +289,6 @@ class ProfileTestCase(_ShadowBannedBase):
{"membership": "join", "displayname": new_display_name},
access_token=self.banned_access_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIn("event_id", channel.json_body)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0048bea54a..0e96697f9b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -92,7 +92,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{},
access_token=self.tok,
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
callback.assert_called_once()
@@ -111,7 +110,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{},
access_token=self.tok,
)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_cannot_modify_event(self):
@@ -131,7 +129,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"500", channel.result)
def test_modify_event(self):
@@ -151,7 +148,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200", channel.result)
event_id = channel.json_body["event_id"]
@@ -161,7 +157,6 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200", channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index ea5a7f3739..7a2c653df8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -94,7 +94,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
def test_room_creation(self):
@@ -108,7 +107,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
def set_alias_via_state_event(self, expected_code, alias_length=5):
@@ -123,7 +121,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def set_alias_via_directory(self, expected_code, alias_length=5):
@@ -134,7 +131,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def random_alias(self, length):
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 3397ba5579..12a93f5687 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -66,14 +66,12 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/events?access_token=%s" % ("invalid" + self.token,)
)
- self.render(request)
self.assertEquals(channel.code, 401, msg=channel.result)
# valid token, expect content
request, channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
- self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("start" in channel.json_body)
@@ -92,7 +90,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/events?access_token=%s&timeout=0" % (self.token,)
)
- self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
# We may get a presence event for ourselves down
@@ -155,5 +152,4 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/events/" + event_id, access_token=self.token,
)
- self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 5d987a30c7..ee75f4076f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -64,7 +64,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -84,7 +83,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -111,7 +109,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -131,7 +128,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -158,7 +154,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "notamonkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -178,7 +173,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "notamonkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
@@ -188,7 +182,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token
request, channel = self.make_request(b"GET", TEST_URL)
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
@@ -199,7 +192,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEquals(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
@@ -209,7 +201,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 200, channel.result)
# time passes
@@ -219,7 +210,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -236,7 +226,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -247,7 +236,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], False)
@@ -257,7 +245,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
@@ -281,7 +268,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token,
content={"auth": auth},
)
- self.render(request)
self.assertEquals(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
@@ -295,7 +281,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 200, channel.result)
# time passes
@@ -305,7 +290,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -314,7 +298,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"POST", "/logout", access_token=access_token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"session_lifetime": "24h"})
@@ -328,7 +311,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 200, channel.result)
# time passes
@@ -338,7 +320,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
- self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
@@ -347,7 +328,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"POST", "/logout/all", access_token=access_token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -423,7 +403,6 @@ class CASTestCase(unittest.HomeserverTestCase):
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
- self.render(request)
# Test that the response is HTML.
self.assertEqual(channel.code, 200)
@@ -468,7 +447,6 @@ class CASTestCase(unittest.HomeserverTestCase):
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
- self.render(request)
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
@@ -495,7 +473,6 @@ class CASTestCase(unittest.HomeserverTestCase):
# Get Synapse to call the fake CAS and serve the template.
request, channel = self.make_request("GET", cas_ticket_url)
- self.render(request)
# Because the user is deactivated they are served an error template.
self.assertEqual(channel.code, 403)
@@ -518,15 +495,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
- def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
+ def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(token, secret, self.jwt_algorithm)
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
params = json.dumps(
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
return channel
def test_login_jwt_valid_registered(self):
@@ -659,7 +639,6 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -725,15 +704,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256"
return self.hs
- def jwt_encode(self, token, secret=jwt_privatekey):
- return jwt.encode(token, secret, "RS256").decode("ascii")
+ def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result = jwt.encode(token, secret, "RS256")
+ if isinstance(result, bytes):
+ return result.decode("ascii")
+ return result
def jwt_login(self, *args):
params = json.dumps(
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
return channel
def test_login_jwt_valid(self):
@@ -766,7 +748,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
{"username": username},
)
- self.render(request)
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
@@ -815,7 +796,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self):
@@ -831,7 +811,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self):
@@ -847,7 +826,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self):
@@ -863,7 +841,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_no_token(self):
@@ -878,5 +855,4 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 3c66255dac..5d5c24d01c 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ presence_handler = Mock()
+ presence_handler.set_state.return_value = defer.succeed(None)
+
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ "red",
+ http_client=None,
+ federation_client=Mock(),
+ presence_handler=presence_handler,
)
- hs.presence_handler = Mock()
- hs.presence_handler.set_state.return_value = defer.succeed(None)
-
return hs
def test_put_presence(self):
@@ -53,10 +56,9 @@ class PresenceTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
- self.render(request)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
def test_put_presence_disabled(self):
"""
@@ -69,7 +71,6 @@ class PresenceTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/presence/%s/status" % (self.user_id,), body
)
- self.render(request)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index ace0a3c08d..383a9eafac 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -195,7 +195,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
content=json.dumps({"displayname": "test"}),
access_token=self.owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
res = self.get_displayname()
@@ -209,7 +208,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
content=json.dumps({"displayname": "test" * 100}),
access_token=self.owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
res = self.get_displayname()
@@ -219,7 +217,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.owner,)
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
@@ -284,7 +281,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.profile_url + url_suffix, access_token=access_token
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def ensure_requester_left_room(self):
@@ -327,7 +323,6 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/profile/" + self.requester, access_token=self.requester_tok
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
request, channel = self.make_request(
@@ -335,7 +330,6 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
"/profile/" + self.requester + "/displayname",
access_token=self.requester_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
request, channel = self.make_request(
@@ -343,5 +337,4 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
"/profile/" + self.requester + "/avatar_url",
access_token=self.requester_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 081052f6a6..7add5523c8 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -48,14 +48,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
@@ -79,7 +77,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# disable the rule
@@ -89,14 +86,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"enabled": False},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check rule disabled
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
@@ -104,21 +99,18 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# PUT a new rule
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
@@ -141,7 +133,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# disable the rule
@@ -151,14 +142,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"enabled": False},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check rule disabled
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], False)
@@ -169,14 +158,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"enabled": True},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check rule enabled
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["enabled"], True)
@@ -198,7 +185,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -206,28 +192,24 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET enabled for that new rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# DELETE the rule
request, channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -242,7 +224,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -260,7 +241,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"enabled": True},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -278,7 +258,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"enabled": True},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -300,14 +279,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET actions for that new rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
@@ -331,7 +308,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# change the rule actions
@@ -341,14 +317,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"actions": ["dont_notify"]},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# GET actions for that new rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/actions", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["actions"], ["dont_notify"])
@@ -370,7 +344,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -378,21 +351,18 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"PUT", "/pushrules/global/override/best.friend", body, access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# DELETE the rule
request, channel = self.make_request(
"DELETE", "/pushrules/global/override/best.friend", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# check 404 for deleted rule
request, channel = self.make_request(
"GET", "/pushrules/global/override/best.friend/enabled", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -407,7 +377,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
request, channel = self.make_request(
"GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -425,7 +394,6 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"actions": ["dont_notify"]},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -443,6 +411,5 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
{"actions": ["dont_notify"]},
access_token=token,
)
- self.render(request)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 9ba5f9d943..49f1073c88 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -86,7 +86,6 @@ class RoomPermissionsTestCase(RoomBase):
request, channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# set topic for public room
@@ -95,7 +94,6 @@ class RoomPermissionsTestCase(RoomBase):
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
# auth as user_id now
@@ -118,12 +116,10 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
@@ -131,19 +127,16 @@ class RoomPermissionsTestCase(RoomBase):
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
request, channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
@@ -154,20 +147,16 @@ class RoomPermissionsTestCase(RoomBase):
request, channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", topic_path)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
@@ -175,12 +164,10 @@ class RoomPermissionsTestCase(RoomBase):
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
request, channel = self.make_request("GET", topic_path)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
@@ -189,29 +176,24 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
request, channel = self.make_request("GET", topic_path)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
request, channel = self.make_request("PUT", topic_path, topic_content)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", topic_path)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
request, channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
@@ -220,14 +202,12 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
request, channel = self.make_request("GET", path)
- self.render(request)
self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self):
@@ -400,18 +380,15 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id)
request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
room_id = self.helper.create_room_as("@some_other_guy:red")
request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
@@ -421,19 +398,16 @@ class RoomsMemberListTestCase(RoomBase):
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
request, channel = self.make_request("GET", room_path)
- self.render(request)
self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
request, channel = self.make_request("GET", room_path)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
request, channel = self.make_request("GET", room_path)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -446,7 +420,6 @@ class RoomsCreateTestCase(RoomBase):
# POST with no config keys, expect new room id
request, channel = self.make_request("POST", "/createRoom", "{}")
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
@@ -455,7 +428,6 @@ class RoomsCreateTestCase(RoomBase):
request, channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private"}'
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
@@ -464,7 +436,6 @@ class RoomsCreateTestCase(RoomBase):
request, channel = self.make_request(
"POST", "/createRoom", b'{"custom":"stuff"}'
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
@@ -473,18 +444,15 @@ class RoomsCreateTestCase(RoomBase):
request, channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
- self.render(request)
self.assertEquals(400, channel.code)
request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
- self.render(request)
self.assertEquals(400, channel.code)
def test_post_room_invitees_invalid_mxid(self):
@@ -493,7 +461,6 @@ class RoomsCreateTestCase(RoomBase):
request, channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
- self.render(request)
self.assertEquals(400, channel.code)
@@ -510,52 +477,42 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self):
# missing keys or invalid json
request, channel = self.make_request("PUT", self.path, "{}")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, '{"nao')
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, "text only")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", self.path, "")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
request, channel = self.make_request("PUT", self.path, content)
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
# nothing should be there
request, channel = self.make_request("GET", self.path)
- self.render(request)
self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
request, channel = self.make_request("PUT", self.path, content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
request, channel = self.make_request("GET", self.path)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -563,12 +520,10 @@ class RoomTopicTestCase(RoomBase):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
request, channel = self.make_request("PUT", self.path, content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
request, channel = self.make_request("GET", self.path)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -585,29 +540,23 @@ class RoomMemberStateTestCase(RoomBase):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
request, channel = self.make_request("PUT", path, "{}")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, '{"nao')
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request(
"PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
)
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, "text only")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, "")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types
@@ -617,7 +566,6 @@ class RoomMemberStateTestCase(RoomBase):
Membership.LEAVE,
)
request, channel = self.make_request("PUT", path, content.encode("ascii"))
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
@@ -629,11 +577,9 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
request, channel = self.make_request("PUT", path, content.encode("ascii"))
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", path, None)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
@@ -649,11 +595,9 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
request, channel = self.make_request("PUT", path, content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", path, None)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -670,11 +614,9 @@ class RoomMemberStateTestCase(RoomBase):
"Join us!",
)
request, channel = self.make_request("PUT", path, content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
request, channel = self.make_request("GET", path, None)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -725,7 +667,6 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
- self.render(request)
self.assertEquals(channel.code, 200, channel.json_body)
# Check that all the rooms have been sent a profile update into.
@@ -736,7 +677,6 @@ class RoomJoinRatelimitTestCase(RoomBase):
)
request, channel = self.make_request("GET", path)
- self.render(request)
self.assertEquals(channel.code, 200)
self.assertIn("displayname", channel.json_body)
@@ -761,7 +701,6 @@ class RoomJoinRatelimitTestCase(RoomBase):
# if all of these requests ended up joining the user to a room.
for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
- self.render(request)
self.assertEquals(channel.code, 200)
@@ -777,29 +716,23 @@ class RoomMessagesTestCase(RoomBase):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
request, channel = self.make_request("PUT", path, b"{}")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b'{"nao')
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request(
"PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
)
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b"text only")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
request, channel = self.make_request("PUT", path, b"")
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
@@ -807,20 +740,17 @@ class RoomMessagesTestCase(RoomBase):
content = b'{"body":"test","msgtype":{"type":"a"}}'
request, channel = self.make_request("PUT", path, content)
- self.render(request)
self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
request, channel = self.make_request("PUT", path, content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
request, channel = self.make_request("PUT", path, content)
- self.render(request)
self.assertEquals(200, channel.code, msg=channel.result["body"])
@@ -837,7 +767,6 @@ class RoomInitialSyncTestCase(RoomBase):
request, channel = self.make_request(
"GET", "/rooms/%s/initialSync" % self.room_id
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"])
@@ -881,7 +810,6 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body["start"])
@@ -893,7 +821,6 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body["start"])
@@ -933,7 +860,6 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
@@ -962,7 +888,6 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
@@ -980,7 +905,6 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
chunk = channel.json_body["chunk"]
@@ -1040,7 +964,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
}
},
)
- self.render(request)
# Check we get the results we expect -- one search result, of the sent
# messages
@@ -1074,7 +997,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
}
},
)
- self.render(request)
# Check we get the results we expect -- one search result, of the sent
# messages
@@ -1111,7 +1033,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
def test_restricted_no_auth(self):
request, channel = self.make_request("GET", self.url)
- self.render(request)
self.assertEqual(channel.code, 401, channel.result)
def test_restricted_auth(self):
@@ -1119,7 +1040,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
tok = self.login("user", "pass")
request, channel = self.make_request("GET", self.url, access_token=tok)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
@@ -1153,7 +1073,6 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
request_data,
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -1168,7 +1087,6 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
request_data,
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
@@ -1177,7 +1095,6 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
@@ -1212,7 +1129,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1227,7 +1143,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1242,7 +1157,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1257,7 +1171,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1270,7 +1183,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1283,7 +1195,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1303,7 +1214,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self._check_for_reason(reason)
@@ -1316,7 +1226,6 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
),
access_token=self.creator_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
event_content = channel.json_body
@@ -1365,7 +1274,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1396,7 +1304,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1432,7 +1339,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1460,7 +1366,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
)
- self.render(request)
events = channel.json_body["chunk"]
@@ -1478,7 +1383,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
% (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
)
- self.render(request)
events = channel.json_body["chunk"]
@@ -1507,7 +1411,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
json.dumps(self.FILTER_LABELS_NOT_LABELS),
),
)
- self.render(request)
events = channel.json_body["chunk"]
@@ -1532,7 +1435,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
- self.render(request)
results = channel.json_body["search_categories"]["room_events"]["results"]
@@ -1568,7 +1470,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
- self.render(request)
results = channel.json_body["search_categories"]["room_events"]["results"]
@@ -1616,7 +1517,6 @@ class LabelsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "/search?access_token=%s" % self.tok, request_data
)
- self.render(request)
results = channel.json_body["search_categories"]["room_events"]["results"]
@@ -1741,7 +1641,6 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1806,7 +1705,6 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=invited_tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
events_before = channel.json_body["events_before"]
@@ -1897,7 +1795,6 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
% (self.room_id,),
access_token=access_token,
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
@@ -1916,7 +1813,6 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -1947,7 +1843,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
@@ -1957,7 +1852,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
access_token=self.room_owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
@@ -1971,7 +1865,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
json.dumps(content),
access_token=self.room_owner_tok,
)
- self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
res = channel.json_body
self.assertIsInstance(res, dict)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index cd58ee7792..bbd30f594b 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -99,7 +99,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
@@ -123,7 +122,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": false}',
)
- self.render(request)
self.assertEquals(200, channel.code)
def test_typing_timeout(self):
@@ -132,7 +130,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
@@ -146,7 +143,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 3)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index afaf9f7b85..737c38c396 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -23,10 +23,11 @@ from typing import Any, Dict, Optional
import attr
from twisted.web.resource import Resource
+from twisted.web.server import Site
from synapse.api.constants import Membership
-from tests.server import make_request, render
+from tests.server import FakeSite, make_request
@attr.s
@@ -36,25 +37,51 @@ class RestHelper:
"""
hs = attr.ib()
- resource = attr.ib()
+ site = attr.ib(type=Site)
auth_user_id = attr.ib()
def create_room_as(
- self, room_creator=None, is_public=True, tok=None, expect_code=200,
- ):
+ self,
+ room_creator: str = None,
+ is_public: bool = True,
+ room_version: str = None,
+ tok: str = None,
+ expect_code: int = 200,
+ ) -> str:
+ """
+ Create a room.
+
+ Args:
+ room_creator: The user ID to create the room with.
+ is_public: If True, the `visibility` parameter will be set to the
+ default (public). Otherwise, the `visibility` parameter will be set
+ to "private".
+ room_version: The room version to create the room as. Defaults to Synapse's
+ default room version.
+ tok: The access token to use in the request.
+ expect_code: The expected HTTP response code.
+
+ Returns:
+ The ID of the newly created room.
+ """
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
+ if room_version:
+ content["room_version"] = room_version
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ path,
+ json.dumps(content).encode("utf8"),
)
- render(request, self.resource, self.hs.get_reactor())
assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
@@ -124,12 +151,14 @@ class RestHelper:
data = {"membership": membership}
data.update(extra_data)
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "PUT",
+ path,
+ json.dumps(data).encode("utf8"),
)
- render(request, self.resource, self.hs.get_reactor())
-
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
@@ -157,10 +186,13 @@ class RestHelper:
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "PUT",
+ path,
+ json.dumps(content).encode("utf8"),
)
- render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@@ -210,9 +242,9 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- request, channel = make_request(self.hs.get_reactor(), method, path, content)
-
- render(request, self.resource, self.hs.get_reactor())
+ _, channel = make_request(
+ self.hs.get_reactor(), self.site, method, path, content
+ )
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
@@ -295,14 +327,15 @@ class RestHelper:
"""
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
- request, channel = make_request(
- self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
- )
- request.requestHeaders.addRawHeader(
- b"Content-Length", str(image_length).encode("UTF-8")
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ FakeSite(resource),
+ "POST",
+ path,
+ content=image_data,
+ access_token=tok,
+ custom_headers=[(b"Content-Length", str(image_length))],
)
- request.render(resource)
- self.hs.get_reactor().pump([100])
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 66ac4dbe85..2ac1ecb7d3 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
+from tests.server import FakeSite, make_request
from tests.unittest import override_config
@@ -245,7 +246,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
return channel.json_body["sid"]
@@ -255,9 +255,14 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
- request, channel = self.make_request("GET", path, shorthand=False)
- request.render(self.submit_token_resource)
- self.pump()
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.submit_token_resource),
+ "GET",
+ path,
+ shorthand=False,
+ )
+
self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
@@ -271,15 +276,15 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
form_args.append(arg)
# Confirm the password reset
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.submit_token_resource),
"POST",
path,
content=urlencode(form_args).encode("utf8"),
shorthand=False,
content_is_form=True,
)
- request.render(self.submit_token_resource)
- self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -319,7 +324,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
},
},
)
- self.render(request)
self.assertEquals(expected_code, channel.code, channel.result)
@@ -349,7 +353,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
# Check that this access token has been invalidated.
request, channel = self.make_request("GET", "account/whoami")
- self.render(request)
self.assertEqual(request.code, 401)
def test_pending_invites(self):
@@ -407,7 +410,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.render(request)
self.assertEqual(request.code, 200)
@@ -542,7 +544,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -550,7 +551,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -575,14 +575,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
request, channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -609,7 +607,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -618,7 +615,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -647,7 +643,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
@@ -655,7 +650,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -682,7 +676,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
@@ -690,7 +683,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@@ -795,7 +787,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"account/3pid/email/requestToken", body,
)
- self.render(request)
self.assertEquals(expect_code, channel.code, channel.result)
return channel.json_body.get("sid")
@@ -808,7 +799,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
@@ -818,7 +808,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
request, channel = self.make_request("GET", path, shorthand=False)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
@@ -867,14 +856,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
request, channel = self.make_request(
"GET", self.url_3pid, access_token=self.user_id_tok,
)
- self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 86184f0d2e..77246e478f 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)
-class DummyPasswordChecker(UserInteractiveAuthChecker):
- def check_auth(self, authdict, clientip):
- return succeed(authdict["identifier"]["user"])
-
-
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -72,7 +67,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "register", body
) # type: SynapseRequest, FakeChannel
- self.render(request)
self.assertEqual(request.code, expected_response)
return channel
@@ -87,7 +81,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
) # type: SynapseRequest, FakeChannel
- self.render(request)
self.assertEqual(request.code, 200)
request, channel = self.make_request(
@@ -96,7 +89,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
+ post_session
+ "&g-recaptcha-response=a",
)
- self.render(request)
self.assertEqual(request.code, expected_post_response)
# The recaptcha handler is called with the response given
@@ -165,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
-
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
@@ -177,7 +166,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "devices", access_token=self.user_tok,
) # type: SynapseRequest, FakeChannel
- self.render(request)
# Get the ID of the device.
self.assertEqual(request.code, 200)
@@ -190,7 +178,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=self.user_tok
) # type: SynapseRequest, FakeChannel
- self.render(request)
# Ensure the response is sane.
self.assertEqual(request.code, expected_response)
@@ -204,7 +191,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "delete_devices", body, access_token=self.user_tok,
) # type: SynapseRequest, FakeChannel
- self.render(request)
# Ensure the response is sane.
self.assertEqual(request.code, expected_response)
@@ -240,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
+ def test_grandfathered_identifier(self):
+ """Check behaviour without "identifier" dict
+
+ Synapse used to require clients to submit a "user" field for m.login.password
+ UIA - check that still works.
+ """
+
+ device_id = self.get_device_ids()[0]
+ channel = self.delete_device(device_id, 401)
+ session = channel.json_body["session"]
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ device_id,
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user,
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
def test_can_change_body(self):
"""
The client dict can be modified during the user interactive authentication session.
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index b9e01c9418..767e126875 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -37,7 +37,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def test_check_auth_required(self):
request, channel = self.make_request("GET", self.url)
- self.render(request)
self.assertEqual(channel.code, 401)
@@ -46,7 +45,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
access_token = self.login("user", "pass")
request, channel = self.make_request("GET", self.url, access_token=access_token)
- self.render(request)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -65,7 +63,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
access_token = self.login(user, password)
request, channel = self.make_request("GET", self.url, access_token=access_token)
- self.render(request)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
@@ -74,7 +71,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(capabilities["m.change_password"]["enabled"])
self.get_success(self.store.user_set_password_hash(user, None))
request, channel = self.make_request("GET", self.url, access_token=access_token)
- self.render(request)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index de00350580..231d5aefea 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -41,7 +41,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
@@ -55,7 +54,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
- self.render(request)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -68,7 +66,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.render(request)
self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
@@ -85,7 +82,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
@@ -94,7 +90,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"404")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -105,7 +100,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"400")
@@ -114,6 +108,5 @@ class FilterTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- self.render(request)
self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index c57072f50c..ee86b94917 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -73,7 +73,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/password_policy"
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(
@@ -91,7 +90,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_too_short(self):
request_data = json.dumps({"username": "kermit", "password": "shorty"})
request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -101,7 +99,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_digit(self):
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -111,7 +108,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_symbol(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -121,7 +117,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_uppercase(self):
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -131,7 +126,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_no_lowercase(self):
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
@@ -141,7 +135,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
def test_password_compliant(self):
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
request, channel = self.make_request("POST", self.register_url, request_data)
- self.render(request)
# Getting a 401 here means the password has passed validation and the server has
# responded with a list of registration flows.
@@ -173,7 +166,6 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
request_data,
access_token=tok,
)
- self.render(request)
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 98c3887bbf..b53b47a0cc 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -64,7 +68,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
@@ -76,26 +79,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
- def test_POST_bad_username(self):
- request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
-
- self.assertEquals(channel.result["code"], b"400", channel.result)
- self.assertEquals(channel.json_body["error"], "Invalid username")
-
def test_POST_user_valid(self):
user_id = "@kermit:test"
device_id = "frogfone"
@@ -107,7 +100,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
det_data = {
"user_id": user_id,
@@ -123,7 +115,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@@ -133,7 +124,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.allow_guest_access = True
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -143,7 +133,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.allow_guest_access = False
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@@ -153,7 +142,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
for i in range(0, 6):
url = self.url + b"?kind=guest"
request, channel = self.make_request(b"POST", url, b"{}")
- self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -164,7 +152,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -179,7 +166,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
if i == 5:
self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -190,13 +176,11 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_advertised_flows(self):
request, channel = self.make_request(b"POST", self.url, b"{}")
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -220,7 +204,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
def test_advertised_flows_captcha_and_terms_and_3pids(self):
request, channel = self.make_request(b"POST", self.url, b"{}")
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -253,7 +236,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
def test_advertised_flows_no_msisdn_email_required(self):
request, channel = self.make_request(b"POST", self.url, b"{}")
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
@@ -298,12 +280,52 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"register/email/requestToken",
{"client_secret": "foobar", "email": email, "send_attempt": 1},
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIsNotNone(channel.json_body.get("sid"))
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -313,6 +335,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -334,14 +357,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
@@ -360,19 +381,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {"user_id": user_id}
request_data = json.dumps(params)
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_manual_expire(self):
@@ -382,7 +401,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
@@ -392,13 +411,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
@@ -411,7 +428,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
- url = "/_matrix/client/unstable/admin/account_validity/validity"
+ url = "/_synapse/admin/v1/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
@@ -421,12 +438,10 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Try to log the user out
request, channel = self.make_request(b"POST", "/logout", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Log the user in again (allowed for expired accounts)
@@ -434,10 +449,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# Try to log out all of the user's sessions
request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
+
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+ # Expire the user
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.reactor.advance(60 * 60 * 1000)
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_synapse/admin/v1/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -500,8 +660,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- # Move 6 days forward. This should trigger a renewal email to be sent.
- self.reactor.advance(datetime.timedelta(days=6).total_seconds())
+ # Move 5 days forward. This should trigger a renewal email to be sent.
+ self.reactor.advance(datetime.timedelta(days=5).total_seconds())
self.assertEqual(len(self.email_attempts), 1)
# Retrieving the URL from the email is too much pain for now, so we
@@ -509,18 +669,35 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
request, channel = self.make_request(b"GET", url)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect on a successful renewal.
- expected_html = self.hs.config.account_validity.account_renewed_html_content
+ expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ expected_html = self.hs.config.account_validity_account_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
+ self.assertEqual(
+ channel.result["body"], expected_html.encode("utf8"), channel.result
+ )
+
+ # Move 1 day forward. Try to renew with the same token again.
+ url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+ request, channel = self.make_request(b"GET", url)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Check that we're getting HTML back.
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+ # Check that the HTML we're getting is the one we expect when reusing a
+ # token. The account expiration date should not have changed.
+ expected_html = self.hs.config.account_validity_account_previously_renewed_template.render(
+ expiration_ts=expiration_ts
+ )
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -530,7 +707,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_renewal_invalid_token(self):
@@ -538,19 +714,15 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
request, channel = self.make_request(b"GET", url)
- self.render(request)
self.assertEquals(channel.result["code"], b"404", channel.result)
# Check that we're getting HTML back.
- content_type = None
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type = header[1]
- self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
+ content_type = channel.headers.getRawHeaders(b"Content-Type")
+ self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
# Check that the HTML we're getting is the one we expect when using an
# invalid/unknown token.
- expected_html = self.hs.config.account_validity.invalid_token_html_content
+ expected_html = self.hs.config.account_validity_invalid_token_template.render()
self.assertEqual(
channel.result["body"], expected_html.encode("utf8"), channel.result
)
@@ -564,7 +736,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -587,7 +758,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.render(request)
self.assertEqual(request.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
@@ -599,7 +769,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -617,7 +787,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -641,7 +811,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -661,7 +830,12 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
config["account_validity"] = {"enabled": False}
self.hs = self.setup_test_homeserver(config=config)
- self.hs.config.account_validity.period = self.validity_period
+
+ # We need to set these directly, instead of in the homeserver config dict above.
+ # This is due to account validity-related config options not being read by
+ # Synapse when account_validity.enabled is False.
+ self.hs.get_datastore()._account_validity_period = self.validity_period
+ self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
self.store = self.hs.get_datastore()
@@ -675,9 +849,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
-
- now_ms = self.hs.clock.time_msec()
+ now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 99c9f4e928..6cd4eb6624 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -65,7 +65,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/event/%s" % (self.room, event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assert_dict(
@@ -114,7 +113,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# We expect to get back a single pagination result, which is the full
@@ -160,7 +158,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, self.parent_id, from_token),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
@@ -219,7 +216,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, self.parent_id, from_token),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
@@ -296,7 +292,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
@@ -336,7 +331,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
@@ -369,7 +363,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
content={},
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
request, channel = self.make_request(
@@ -378,7 +371,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
@@ -396,7 +388,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, self.parent_id, RelationTypes.REPLACE),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(400, channel.code, channel.json_body)
def test_aggregation_get_event(self):
@@ -428,7 +419,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(
@@ -465,7 +455,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body)
@@ -523,7 +512,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEquals(channel.json_body["content"], new_body)
@@ -567,7 +555,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, original_event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
@@ -581,7 +568,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
content="{}",
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# Try to check for remaining m.replace relations
@@ -591,7 +577,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, original_event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# Check that no relations are returned
@@ -624,7 +609,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token=self.user_token,
content="{}",
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
# Check that aggregations returns zero
@@ -634,7 +618,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
% (self.room, original_event_id),
access_token=self.user_token,
)
- self.render(request)
self.assertEquals(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
@@ -680,7 +663,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
json.dumps(content).encode("utf-8"),
access_token=access_token,
)
- self.render(request)
return channel
def _create_user(self, localpart):
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 5ae72fd008..562a9c1ba4 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -47,7 +47,6 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
% other_user,
access_token=token,
)
- self.render(request)
return request, channel
def test_shared_room_list_public(self):
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a31e44c97e..31ac0fccb8 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -36,7 +36,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -57,7 +56,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.hs.config.use_presence = False
request, channel = self.make_request("GET", "/sync")
- self.render(request)
self.assertEqual(channel.code, 200)
self.assertTrue(
@@ -199,7 +197,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
@@ -253,13 +250,11 @@ class SyncTypingTests(unittest.HomeserverTestCase):
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
request, channel = self.make_request(
"GET", "/sync?access_token=%s" % (access_token,)
)
- self.render(request)
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -269,7 +264,6 @@ class SyncTypingTests(unittest.HomeserverTestCase):
typing_url % (room, other_user_id, other_access_token),
b'{"typing": false}',
)
- self.render(request)
self.assertEquals(200, channel.code)
# Start typing.
@@ -278,14 +272,12 @@ class SyncTypingTests(unittest.HomeserverTestCase):
typing_url % (room, other_user_id, other_access_token),
b'{"typing": true, "timeout": 30000}',
)
- self.render(request)
self.assertEquals(200, channel.code)
# Should return immediately
request, channel = self.make_request(
"GET", sync_url % (access_token, next_batch)
)
- self.render(request)
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -300,7 +292,6 @@ class SyncTypingTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", sync_url % (access_token, next_batch)
)
- self.render(request)
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -311,7 +302,6 @@ class SyncTypingTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", sync_url % (access_token, next_batch)
)
- self.render(request)
self.assertEquals(200, channel.code)
next_batch = channel.json_body["next_batch"]
@@ -320,10 +310,8 @@ class SyncTypingTests(unittest.HomeserverTestCase):
typing._reset()
# Now it SHOULD fail as it never completes!
- request, channel = self.make_request(
- "GET", sync_url % (access_token, next_batch)
- )
- self.assertRaises(TimedOutException, self.render, request)
+ with self.assertRaises(TimedOutException):
+ self.make_request("GET", sync_url % (access_token, next_batch))
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
@@ -401,7 +389,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
body,
access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the unread counter is back to 0.
@@ -466,7 +453,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", self.url % self.next_batch, access_token=self.tok,
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6850c666be..fbcf8d5b86 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.server import FakeChannel, wait_until_result
+from tests.server import FakeChannel
from tests.utils import default_config
@@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.http_client = Mock()
return self.setup_test_homeserver(http_client=self.http_client)
- def create_test_json_resource(self):
+ def create_test_resource(self):
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
% (server_name.encode("utf-8"), key_id.encode("utf-8")),
b"1.1",
)
- wait_until_result(self.reactor, req)
+ channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
@@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.requestReceived(
b"POST", path.encode("utf-8"), b"1.1",
)
- wait_until_result(self.reactor, req)
+ channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 5f897d49cf..2a3b2a8f27 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
+from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -227,8 +228,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition):
- request, channel = self.make_request("GET", self.media_id, shorthand=False)
- request.render(self.download_resource)
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
+ "GET",
+ self.media_id,
+ shorthand=False,
+ await_result=False,
+ )
self.pump()
# We've made one fetch, to example.com, using the media URL, and asking
@@ -317,10 +324,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
- request, channel = self.make_request(
- "GET", self.media_id + params, shorthand=False
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.thumbnail_resource),
+ "GET",
+ self.media_id + params,
+ shorthand=False,
+ await_result=False,
)
- request.render(self.thumbnail_resource)
self.pump()
headers = {
@@ -348,7 +359,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
- "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
- % method,
+ "error": "Not found [b'example.com', b'12345']",
},
)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index c00a7b9114..ccdc8c2ecf 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -133,13 +133,18 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
+ def create_test_resource(self):
+ return self.hs.get_media_repository_resource()
+
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -160,10 +165,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check the cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET", "preview_url?url=http://matrix.org", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# Check the cache response has the same content
self.assertEqual(channel.code, 200)
@@ -178,10 +181,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check the database cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET", "preview_url?url=http://matrix.org", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# Check the cache response has the same content
self.assertEqual(channel.code, 200)
@@ -201,9 +202,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -234,9 +237,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -267,9 +272,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -298,9 +305,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -326,10 +335,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -349,10 +356,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -368,10 +373,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blacklisted IP addresses, accessed directly, are not spidered.
"""
request, channel = self.make_request(
- "GET", "url_preview?url=http://192.168.1.1", shorthand=False
+ "GET", "preview_url?url=http://192.168.1.1", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -389,10 +392,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blacklisted IP ranges, accessed directly, are not spidered.
"""
request, channel = self.make_request(
- "GET", "url_preview?url=http://1.1.1.2", shorthand=False
+ "GET", "preview_url?url=http://1.1.1.2", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 403)
self.assertEqual(
@@ -411,9 +412,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -446,10 +449,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
@@ -468,10 +469,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -491,10 +490,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -510,10 +507,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
OPTIONS returns the OPTIONS.
"""
request, channel = self.make_request(
- "OPTIONS", "url_preview?url=http://example.com", shorthand=False
+ "OPTIONS", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
@@ -525,9 +520,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Build and make a request to the server
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
# Extract Synapse's tcp client
@@ -598,10 +595,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET",
- "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -663,10 +660,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET",
- "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 2d021f6565..02a46e5fda 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -20,15 +20,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
- def setUp(self):
- super().setUp()
-
+ def create_test_resource(self):
# replace the JsonResource with a HealthResource.
- self.resource = HealthResource()
+ return HealthResource()
def test_health(self):
request, channel = self.make_request("GET", "/health", shorthand=False)
- self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index dcd65c2a50..6a930f4148 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -20,11 +20,9 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
- def setUp(self):
- super().setUp()
-
+ def create_test_resource(self):
# replace the JsonResource with a WellKnownResource
- self.resource = WellKnownResource(self.hs)
+ return WellKnownResource(self.hs)
def test_well_known(self):
self.hs.config.public_baseurl = "https://tesths"
@@ -33,7 +31,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(
@@ -50,6 +47,5 @@ class WellKnownTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.render(request)
self.assertEqual(request.code, 404)
diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py
new file mode 100644
index 0000000000..0937f304cf
--- /dev/null
+++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+
+import synapse.rest.admin
+from synapse.config._base import ConfigError
+from synapse.rest.client.v1 import login, room
+from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
+
+from tests import unittest
+from tests.server import make_request
+
+
+class DomainRuleCheckerTestCase(unittest.TestCase):
+ def test_allowed(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ "domains_prevented_from_being_invited_to_published_rooms": ["target_two"],
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_one", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False
+ )
+ )
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_two", "test:target_two", None, "room", False
+ )
+ )
+
+ # User can invite internal user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test1:target_one", None, "room", False, True
+ )
+ )
+
+ # User can invite external user to a non-published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, False
+ )
+ )
+
+ def test_disallowed(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ "source_four": [],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_one", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_three", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_two", "test:target_one", None, "room", False
+ )
+ )
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_four", "test:target_one", None, "room", False
+ )
+ )
+
+ # User cannot invite external user to a published room
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_one", "test:target_two", None, "room", False, True
+ )
+ )
+
+ def test_default_allow(self):
+ config = {
+ "default": True,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertTrue(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_default_deny(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ check = DomainRuleChecker(config)
+ self.assertFalse(
+ check.user_may_invite(
+ "test:source_three", "test:target_one", None, "room", False
+ )
+ )
+
+ def test_config_parse(self):
+ config = {
+ "default": False,
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ },
+ }
+ self.assertEquals(config, DomainRuleChecker.parse_config(config))
+
+ def test_config_parse_failure(self):
+ config = {
+ "domain_mapping": {
+ "source_one": ["target_one", "target_two"],
+ "source_two": ["target_two"],
+ }
+ }
+ self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
+
+
+class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["localhost"]
+
+ config["spam_checker"] = {
+ "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker",
+ "config": {
+ "default": True,
+ "domain_mapping": {},
+ "can_only_join_rooms_with_invite": True,
+ "can_only_create_one_to_one_rooms": True,
+ "can_only_invite_during_room_creation": True,
+ "can_invite_by_third_party_id": False,
+ },
+ }
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user_id = self.register_user("admin_user", "pass", admin=True)
+ self.admin_access_token = self.login("admin_user", "pass")
+
+ self.normal_user_id = self.register_user("normal_user", "pass", admin=False)
+ self.normal_access_token = self.login("normal_user", "pass")
+
+ self.other_user_id = self.register_user("other_user", "pass", admin=False)
+
+ def test_admin_can_create_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_normal_user_cannot_create_empty_room(self):
+ channel = self._create_room(self.normal_access_token)
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_cannot_create_room_with_multiple_invites(self):
+ channel = self._create_room(
+ self.normal_access_token,
+ content={"invite": [self.other_user_id, self.admin_user_id]},
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly counts both normal and third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [self.other_user_id],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ # Test that it correctly rejects third party invites
+ channel = self._create_room(
+ self.normal_access_token,
+ content={
+ "invite": [],
+ "invite_3pid": [{"medium": "email", "address": "foo@example.com"}],
+ },
+ )
+ assert channel.result["code"] == b"403", channel.result
+
+ def test_normal_user_can_room_with_single_invites(self):
+ channel = self._create_room(
+ self.normal_access_token, content={"invite": [self.other_user_id]}
+ )
+ assert channel.result["code"] == b"200", channel.result
+
+ def test_cannot_join_public_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403
+ )
+
+ def test_can_join_invited_room(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ def test_cannot_invite(self):
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ def test_cannot_3pid_invite(self):
+ """Test that unbound 3pid invites get rejected.
+ """
+ channel = self._create_room(self.admin_access_token)
+ assert channel.result["code"] == b"200", channel.result
+
+ room_id = channel.json_body["room_id"]
+
+ self.helper.invite(
+ room_id,
+ src=self.admin_user_id,
+ targ=self.normal_user_id,
+ tok=self.admin_access_token,
+ )
+
+ self.helper.join(
+ room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200
+ )
+
+ self.helper.invite(
+ room_id,
+ src=self.normal_user_id,
+ targ=self.other_user_id,
+ tok=self.normal_access_token,
+ expect_code=403,
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/invite" % (room_id),
+ {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"},
+ access_token=self.normal_access_token,
+ )
+ self.assertEqual(channel.code, 403, channel.result["body"])
+
+ def _create_room(self, token, content={}):
+ path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,)
+
+ request, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ path,
+ content=json.dumps(content).encode("utf8"),
+ )
+
+ return channel
diff --git a/tests/server.py b/tests/server.py
index 3dd2cfc072..a51ad0c14e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable
+from typing import Callable, Iterable, Optional, Tuple, Union
import attr
from typing_extensions import Deque
@@ -19,8 +19,8 @@ from twisted.internet.interfaces import (
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
-from twisted.web.http import unquote
from twisted.web.http_headers import Headers
+from twisted.web.resource import IResource
from twisted.web.server import Site
from synapse.http.site import SynapseRequest
@@ -117,6 +117,25 @@ class FakeChannel:
def transport(self):
return self
+ def await_result(self, timeout: int = 100) -> None:
+ """
+ Wait until the request is finished.
+ """
+ self._reactor.run()
+ x = 0
+
+ while not self.result.get("done"):
+ # If there's a producer, tell it to resume producing so we get content
+ if self._producer:
+ self._producer.resumeProducing()
+
+ x += 1
+
+ if x > timeout:
+ raise TimedOutException("Timed out waiting for request to finish.")
+
+ self._reactor.advance(0.1)
+
class FakeSite:
"""
@@ -128,9 +147,21 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
+ def __init__(self, resource: IResource):
+ """
+
+ Args:
+ resource: the resource to be used for rendering all requests
+ """
+ self._resource = resource
+
+ def getResourceFor(self, request):
+ return self._resource
+
def make_request(
reactor,
+ site: Site,
method,
path,
content=b"",
@@ -139,12 +170,19 @@ def make_request(
shorthand=True,
federation_auth_origin=None,
content_is_form=False,
+ await_result: bool = True,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
):
"""
- Make a web request using the given method and path, feed it the
- content, and return the Request and the Channel underneath.
+ Make a web request using the given method, path and content, and render it
+
+ Returns the Request and the Channel underneath.
Args:
+ site: The twisted Site to use to render the request
+
method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such).
@@ -157,6 +195,12 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ custom_headers: (name, value) pairs to add as request headers
+
+ await_result: whether to wait for the request to complete rendering. If true,
+ will pump the reactor until the the renderer tells the channel the request
+ is finished.
+
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
@@ -178,18 +222,17 @@ def make_request(
if not path.startswith(b"/"):
path = b"/" + path
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
if isinstance(content, str):
content = content.encode("utf8")
- site = FakeSite()
channel = FakeChannel(site, reactor)
req = request(channel)
- req.process = lambda: b""
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END)
- req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
req.requestHeaders.addRawHeader(
@@ -211,35 +254,16 @@ def make_request(
# Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
- req.requestReceived(method, path, b"1.1")
-
- return req, channel
-
+ if custom_headers:
+ for k, v in custom_headers:
+ req.requestHeaders.addRawHeader(k, v)
-def wait_until_result(clock, request, timeout=100):
- """
- Wait until the request is finished.
- """
- clock.run()
- x = 0
-
- while not request.finished:
-
- # If there's a producer, tell it to resume producing so we get content
- if request._channel._producer:
- request._channel._producer.resumeProducing()
-
- x += 1
-
- if x > timeout:
- raise TimedOutException("Timed out waiting for request to finish.")
-
- clock.advance(0.1)
+ req.requestReceived(method, path, b"1.1")
+ if await_result:
+ channel.await_result()
-def render(request, resource, clock):
- request.render(resource)
- wait_until_result(clock, request)
+ return req, channel
@implementer(IReactorPluggableNameResolver)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 872039c8f1..e0a9cd93ac 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -73,7 +73,6 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# Get the Room ID to join
@@ -85,14 +84,12 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/" + room_id + "/join",
access_token=self.access_token,
)
- self.render(request)
self.assertEqual(channel.code, 200)
# Sync again, to get the message in the room
request, channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=self.access_token
)
- self.render(request)
self.assertEqual(channel.code, 200)
# Get the message
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 6382b19dc3..9c8027a5b2 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -306,7 +306,6 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
tok = self.login("user", "password")
request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
- self.render(request)
invites = channel.json_body["rooms"]["invite"]
self.assertEqual(len(invites), 0, invites)
@@ -320,7 +319,6 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
# Sync again to retrieve the events in the room, so we can check whether this
# room has a notice in it.
request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
- self.render(request)
# Scan the events in the room to search for a message from the server notices
# user.
@@ -358,7 +356,6 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET", "/sync?timeout=0", access_token=tok,
)
- self.render(request)
# Also retrieves the list of invites for this user. We don't care about that
# one except if we're processing the last user, which should have received an
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 5a1e5c4e66..c13a57dad1 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
- @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_event_without_consent(self):
- self._create_extremity_rich_graph()
- self._enable_consent_checking()
-
- # Pump the reactor repeatedly so that the background updates have a
- # chance to run. Attempt to add dummy event with user that has not consented
- # Check that dummy event send fails.
- self.pump(10 * 60)
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
-
- # Create new user, and add consent
- user2 = self.register_user("user2", "password")
- token2 = self.login("user2", "password")
- self.get_success(
- self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
- )
- self.helper.join(self.room_id, user2, tok=token2)
-
- # Background updates should now cause a dummy event to be added to the graph
- self.pump(10 * 60)
-
- latest_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(self.room_id)
- )
- self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
-
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
def test_expiry_logic(self):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index e96ca1c8ca..a69117c5a9 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login
from tests import unittest
+from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -408,18 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds())
- request, channel = self.make_request(
+ headers1 = {b"User-Agent": b"Mozzila pizza"}
+ headers1.update(headers)
+
+ make_request(
+ self.reactor,
+ self.site,
"GET",
- "/_matrix/client/r0/admin/users/" + self.user_id,
+ "/_synapse/admin/v1/users/" + self.user_id,
access_token=access_token,
+ custom_headers=headers1.items(),
**make_request_args,
)
- request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
-
- # Add the optional headers
- for h, v in headers.items():
- request.requestHeaders.addRawHeader(h, v)
- self.render(request)
# Advance so the save loop occurs
self.reactor.advance(100)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..fe37d2ed5a 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -39,7 +39,7 @@ class DataStoreTestCase(unittest.TestCase):
)
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
)
users, total = yield defer.ensureDeferred(
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..7a38022e71 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -36,7 +36,7 @@ class ProfileStoreTestCase(unittest.TestCase):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
)
self.assertEquals(
@@ -54,7 +54,7 @@ class ProfileStoreTestCase(unittest.TestCase):
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.localpart, "http://my.site/here", 1
)
)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 654a6fa42d..c5ec6396a7 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -202,7 +202,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
)
request, channel = self.make_request("POST", "/register", request_data)
- self.render(request)
if channel.code != 200:
raise HttpResponseException(
@@ -215,7 +214,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
def do_sync_for_user(self, token):
request, channel = self.make_request("GET", "/sync", access_token=token)
- self.render(request)
if channel.code != 200:
raise HttpResponseException(
diff --git a/tests/test_server.py b/tests/test_server.py
index 655c918a15..c387a85f2e 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,9 +26,9 @@ from synapse.util import Clock
from tests import unittest
from tests.server import (
+ FakeSite,
ThreadedMemoryReactorClock,
make_request,
- render,
setup_test_homeserver,
)
@@ -62,9 +62,8 @@ class JsonResourceTests(unittest.TestCase):
)
request, channel = make_request(
- self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+ self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
)
- render(request, res, self.reactor)
self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
@@ -83,8 +82,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -108,8 +106,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"500")
@@ -127,8 +124,7 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -150,8 +146,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
- render(request, res, self.reactor)
+ _, channel = make_request(
+ self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
+ )
self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -173,8 +170,7 @@ class JsonResourceTests(unittest.TestCase):
)
# The path was registered as GET, but this is a HEAD request.
- request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
@@ -196,9 +192,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response."""
- request, channel = make_request(self.reactor, method, path, shorthand=False)
- request.prepath = [] # This doesn't get set properly by make_request.
-
# Create a site and query for the resource.
site = SynapseSite(
"test",
@@ -207,11 +200,9 @@ class OptionsResourceTests(unittest.TestCase):
self.resource,
"1.0",
)
- request.site = site
- resource = site.getResourceFor(request)
- # Finally, render the resource and return the channel.
- render(request, resource, self.reactor)
+ # render the request and return the channel
+ _, channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
def test_unknown_options_request(self):
@@ -284,8 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"GET", b"/path")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
@@ -303,8 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"GET", b"/path")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
@@ -325,8 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"GET", b"/path")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
@@ -345,8 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- request, channel = make_request(self.reactor, b"HEAD", b"/path")
- render(request, res, self.reactor)
+ _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 80b0ccbc40..6227a3ba95 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
+ "hostname",
]
)
hs.config = default_config("tesths", True)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index b89798336c..71580b454d 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -54,7 +54,6 @@ class TermsTestCase(unittest.HomeserverTestCase):
# Do a UI auth request
request_data = json.dumps({"username": "kermit", "password": "monkey"})
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -98,7 +97,6 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler.check_username = Mock(return_value=True)
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
# We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path.
@@ -116,7 +114,6 @@ class TermsTestCase(unittest.HomeserverTestCase):
}
)
request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
# We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want.
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -12,9 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from six import string_types
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ GroupID,
+ RoomAlias,
+ UserID,
+ map_username_to_mxid_localpart,
+ strip_invalid_mxid_characters,
+)
from tests import unittest
@@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase):
self.assertEqual(
map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast"
)
+
+
+class StripInvalidMxidCharactersTestCase(unittest.TestCase):
+ def test_return_type(self):
+ unstripped = strip_invalid_mxid_characters("test")
+ stripped = strip_invalid_mxid_characters("test@")
+
+ self.assertTrue(isinstance(unstripped, string_types), type(unstripped))
+ self.assertTrue(isinstance(stripped, string_types), type(stripped))
+
+ def test_strip(self):
+ stripped = strip_invalid_mxid_characters("test@")
+ self.assertEqual(stripped, "test", stripped)
diff --git a/tests/unittest.py b/tests/unittest.py
index e36ac89196..a9d59e31f7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
+from twisted.web.resource import Resource
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
@@ -47,13 +48,7 @@ from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
-from tests.server import (
- FakeChannel,
- get_clock,
- make_request,
- render,
- setup_test_homeserver,
-)
+from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -239,10 +234,8 @@ class HomeserverTestCase(TestCase):
if not isinstance(self.hs, HomeServer):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
- # Register the resources
- self.resource = self.create_test_json_resource()
-
- # create a site to wrap the resource.
+ # create the root resource, and a site to wrap it.
+ self.resource = self.create_test_resource()
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag=self.hs.config.server.server_name,
@@ -253,7 +246,7 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper
- self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+ self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"):
if self.hijack_auth:
@@ -323,15 +316,12 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver()
return hs
- def create_test_json_resource(self):
+ def create_test_resource(self) -> Resource:
"""
- Create a test JsonResource, with the relevant servlets registerd to it
-
- The default implementation calls each function in `servlets` to do the
- registration.
+ Create a the root resource for the test server.
- Returns:
- JsonResource:
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servletes against it
"""
resource = JsonResource(self.hs)
@@ -381,6 +371,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[SynapseRequest, FakeChannel]:
...
@@ -395,6 +386,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[T, FakeChannel]:
...
@@ -408,6 +400,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -426,14 +419,16 @@ class HomeserverTestCase(TestCase):
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ await_result: whether to wait for the request to complete rendering. If
+ true (the default), will pump the test reactor until the the renderer
+ tells the channel the request is finished.
+
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
- if isinstance(content, dict):
- content = json.dumps(content).encode("utf8")
-
return make_request(
self.reactor,
+ self.site,
method,
path,
content,
@@ -442,18 +437,9 @@ class HomeserverTestCase(TestCase):
shorthand,
federation_auth_origin,
content_is_form,
+ await_result,
)
- def render(self, request):
- """
- Render a request against the resources registered by the test class's
- servlets.
-
- Args:
- request (synapse.http.site.SynapseRequest): The request to render.
- """
- render(request, self.resource, self.reactor)
-
def setup_test_homeserver(self, *args, **kwargs):
"""
Set up the test homeserver, meant to be called by the overridable
@@ -568,8 +554,7 @@ class HomeserverTestCase(TestCase):
self.hs.config.registration_shared_secret = "shared"
# Create the user
- request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
- self.render(request)
+ request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -595,9 +580,8 @@ class HomeserverTestCase(TestCase):
}
)
request, channel = self.make_request(
- "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
+ "POST", "/_synapse/admin/v1/register", body.encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
@@ -616,7 +600,6 @@ class HomeserverTestCase(TestCase):
request, channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
@@ -685,7 +668,6 @@ class HomeserverTestCase(TestCase):
request, channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 403, channel.result)
def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
diff --git a/tests/utils.py b/tests/utils.py
index acec74e9e9..1584eacb12 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -176,6 +176,8 @@ def default_config(name, parse=False):
"update_user_directory": False,
"caches": {"global_factor": 1},
"listeners": [{"port": 0, "type": "http"}],
+ # Enable encryption by default in private rooms
+ "encryption_enabled_by_default_for_room_type": "invite",
}
if parse:
@@ -271,7 +273,7 @@ def setup_test_homeserver(
# Install @cache_in_self attributes
for key, val in kwargs.items():
- setattr(hs, key, val)
+ setattr(hs, "_" + key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
|