diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 913ea3c98e..5256c11fe6 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -73,7 +73,9 @@ class PowerLevelsTestCase(HomeserverTestCase):
# Mod the mod
room_power_levels = self.helper.get_state(
- self.room_id, "m.room.power_levels", tok=self.admin_access_token,
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
)
# Update existing power levels with mod at PL50
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index f0707646bb..e0c74591b6 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -181,8 +181,7 @@ class RedactionsTestCase(HomeserverTestCase):
)
def test_redact_event_as_moderator_ratelimit(self):
- """Tests that the correct ratelimiting is applied to redactions
- """
+ """Tests that the correct ratelimiting is applied to redactions"""
message_ids = []
# as a regular user, send messages to redact
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 10b1fbac69..b8285f3240 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -252,7 +252,8 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["backfill"])
self.hs = self.setup_test_homeserver(
- config=config, federation_client=mock_federation_client,
+ config=config,
+ federation_client=mock_federation_client,
)
return self.hs
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 0ebdf1415b..d2cce44032 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -260,7 +260,10 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ self.banned_user_id,
+ room_id,
+ "m.room.member",
+ self.banned_user_id,
)
)
self.assertEqual(
@@ -292,7 +295,10 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ self.banned_user_id,
+ room_id,
+ "m.room.member",
+ self.banned_user_id,
)
)
self.assertEqual(
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 0a5ca317ea..2ae896db1e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -150,6 +150,8 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
event_id = resp["event_id"]
channel = self.make_request(
- "GET", "/events/" + event_id, access_token=self.token,
+ "GET",
+ "/events/" + event_id,
+ access_token=self.token,
)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 49543d9acb..fb29eaed6f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -611,7 +611,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
- "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
@@ -619,7 +621,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
- "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
@@ -719,7 +722,8 @@ class CASTestCase(unittest.HomeserverTestCase):
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
- config=config, proxied_http_client=mocked_http_client,
+ config=config,
+ proxied_http_client=mocked_http_client,
)
return self.hs
@@ -1244,7 +1248,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# looks ok.
username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
self.assertIn(
- session_id, username_mapping_sessions, "session id not found in map",
+ session_id,
+ username_mapping_sessions,
+ "session id not found in map",
)
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
@@ -1299,7 +1305,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
chan = self.make_request(
- "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index e59fa70baa..f3448c94dd 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,163 +14,11 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-import json
-
-from mock import Mock
-
-from twisted.internet import defer
-
-import synapse.types
-from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
-from ....utils import MockHttpResource, setup_test_homeserver
-
-myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/r0"
-
-
-class MockHandlerProfileTestCase(unittest.TestCase):
- """ Tests rest layer of profile management.
-
- Todo: move these into ProfileTestCase
- """
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.mock_handler = Mock(
- spec=[
- "get_displayname",
- "set_displayname",
- "get_avatar_url",
- "set_avatar_url",
- "check_profile_query_allowed",
- ]
- )
-
- self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
- Mock()
- )
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- "test",
- federation_http_client=None,
- resource_for_client=self.mock_resource,
- federation=Mock(),
- federation_client=Mock(),
- profile_handler=self.mock_handler,
- )
-
- async def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
-
- hs.get_auth().get_user_by_req = _get_user_by_req
-
- profile.register_servlets(hs, self.mock_resource)
-
- @defer.inlineCallbacks
- def test_get_my_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Frank")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Frank"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
-
- @defer.inlineCallbacks
- def test_set_my_name_noauth(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = AuthError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@4567:test"),
- b'{"displayname": "Frank Jr."}',
- )
-
- self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_other_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Bob")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Bob"}, response)
-
- @defer.inlineCallbacks
- def test_set_other_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = SynapseError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@opaque:elsewhere"),
- b'{"displayname":"bob"}',
- )
-
- self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_my_avatar(self):
- mocked_get = self.mock_handler.get_avatar_url
- mocked_get.return_value = defer.succeed("http://my.server/me.png")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/avatar_url" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_avatar(self):
- mocked_set = self.mock_handler.set_avatar_url
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/avatar_url" % (myid),
- b'{"avatar_url": "http://my.server/pic.gif"}',
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
-
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
+ self.other = self.register_user("other", "pass", displayname="Bob")
+
+ def test_get_displayname(self):
+ res = self._get_displayname()
+ self.assertEqual(res, "owner")
def test_set_displayname(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test"}),
+ content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "test")
+ def test_set_displayname_noauth(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner,),
+ content={"displayname": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test" * 100}),
+ content={"displayname": "test" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "owner")
- def get_displayname(self):
- channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
+ def test_get_displayname_other(self):
+ res = self._get_displayname(self.other)
+ self.assertEquals(res, "Bob")
+
+ def test_set_displayname_other(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.other,),
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_get_avatar_url(self):
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_set_avatar_url(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertEqual(res, "http://my.server/pic.gif")
+
+ def test_set_avatar_url_noauth(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_set_avatar_url_too_long(self):
+ """Attempts to set a stupid avatar_url should get a 400"""
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif" * 100},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_get_avatar_url_other(self):
+ res = self._get_avatar_url(self.other)
+ self.assertIsNone(res)
+
+ def test_set_avatar_url_other(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.other,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_displayname(self, name=None):
+ channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (name or self.owner,)
+ )
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
+ def _get_avatar_url(self, name=None):
+ channel = self.make_request(
+ "GET", "/profile/%s/avatar_url" % (name or self.owner,)
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body.get("avatar_url")
+
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2548b3a80c..ed65f645fc 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
@@ -1480,7 +1482,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 2, [result["result"]["content"] for result in results],
+ len(results),
+ 2,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1515,7 +1519,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 4, [result["result"]["content"] for result in results],
+ len(results),
+ 4,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1562,7 +1568,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 1, [result["result"]["content"] for result in results],
+ len(results),
+ 1,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..329dbd06de 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.rest.client.v1 import room
from synapse.types import UserID
@@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
@@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_datastore().insert_client_ip = _insert_client_ip
- def get_room_members(room_id):
- if room_id == self.room_id:
- return defer.succeed([self.user])
- else:
- return defer.succeed([])
-
- @defer.inlineCallbacks
- def fetch_room_distributions_into(
- room_id, localusers=None, remotedomains=None, ignore_user=None
- ):
- members = yield get_room_members(room_id)
- for member in members:
- if ignore_user is not None and member == ignore_user:
- continue
-
- if hs.is_mine(member):
- if localusers is not None:
- localusers.add(member)
- else:
- if remotedomains is not None:
- remotedomains.add(member.domain)
-
- hs.get_room_member_handler().fetch_room_distributions_into = (
- fetch_room_distributions_into
- )
-
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index b1333df82d..8231a423f3 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -166,9 +166,12 @@ class RestHelper:
json.dumps(data).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
self.auth_user_id = temp_id
@@ -201,9 +204,12 @@ class RestHelper:
json.dumps(content).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -251,9 +257,12 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -447,7 +456,10 @@ class RestHelper:
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
def complete_oidc_auth(
- self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+ self,
+ oauth_uri: str,
+ cookies: Mapping[str, str],
+ user_info_dict: JsonDict,
) -> FakeChannel:
"""Mock out an OIDC authentication flow
@@ -491,7 +503,9 @@ class RestHelper:
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ code=200,
+ phrase=b"OK",
+ body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 177dc476da..e72b61963d 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -75,8 +75,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
- """Test basic password reset flow
- """
+ """Test basic password reset flow"""
old_password = "monkey"
new_password = "kangeroo"
@@ -114,8 +113,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self):
- """Test that we ratelimit /requestToken for the same email.
- """
+ """Test that we ratelimit /requestToken for the same email."""
old_password = "monkey"
new_password = "kangeroo"
@@ -203,8 +201,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", old_password)
def test_cant_reset_password_without_clicking_link(self):
- """Test that we do actually need to click the link in the email
- """
+ """Test that we do actually need to click the link in the email"""
old_password = "monkey"
new_password = "kangeroo"
@@ -299,7 +296,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
if channel.code != 200:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.result["body"],
+ channel.code,
+ channel.result["reason"],
+ channel.result["body"],
)
return channel.json_body["sid"]
@@ -566,8 +565,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self):
- """Tests that adding emails is ratelimited by IP
- """
+ """Tests that adding emails is ratelimited by IP"""
# We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
@@ -580,8 +578,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(cm.exception.code, 429)
def test_add_email_if_disabled(self):
- """Test adding email to profile when doing so is disallowed
- """
+ """Test adding email to profile when doing so is disallowed"""
self.hs.config.enable_3pid_changes = False
client_secret = "foobar"
@@ -611,15 +608,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self):
- """Test deleting an email from profile
- """
+ """Test deleting an email from profile"""
# Add a threepid
self.get_success(
self.store.user_add_threepid(
@@ -641,15 +639,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self):
- """Test deleting an email from profile when disallowed
- """
+ """Test deleting an email from profile when disallowed"""
self.hs.config.enable_3pid_changes = False
# Add a threepid
@@ -675,7 +674,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -683,8 +684,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
def test_cant_add_email_without_clicking_link(self):
- """Test that we do actually need to click the link in the email
- """
+ """Test that we do actually need to click the link in the email"""
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
@@ -710,7 +710,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -743,7 +745,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -788,7 +792,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Ensure not providing a next_link parameter still works
self._request_token(
- "something@example.com", "some_secret", next_link=None, expect_code=200,
+ "something@example.com",
+ "some_secret",
+ next_link=None,
+ expect_code=200,
)
self._request_token(
@@ -846,17 +853,27 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if next_link:
body["next_link"] = next_link
- channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
+ channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ body,
+ )
if channel.code != expect_code:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.result["body"],
+ channel.code,
+ channel.result["reason"],
+ channel.result["body"],
)
return channel.json_body.get("sid")
def _request_token_invalid_email(
- self, email, expected_errcode, expected_error, client_secret="foobar",
+ self,
+ email,
+ expected_errcode,
+ expected_error,
+ client_secret="foobar",
):
channel = self.make_request(
"POST",
@@ -895,8 +912,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
return match.group(0)
def _add_email(self, request_email, expected_email):
- """Test adding an email to profile
- """
+ """Test adding an email to profile"""
previous_email_attempts = len(self.email_attempts)
client_secret = "foobar"
@@ -926,7 +942,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
# Get user
channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ "GET",
+ self.url_3pid,
+ access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 3f50c56745..501f09203f 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -102,7 +102,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
- 401, {"username": "user", "type": "m.login.password", "password": "bar"},
+ 401,
+ {"username": "user", "type": "m.login.password", "password": "bar"},
)
# Grab the session
@@ -191,7 +192,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
) -> FakeChannel:
"""Delete an individual device."""
channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=access_token,
+ "DELETE",
+ "devices/" + device,
+ body,
+ access_token=access_token,
)
# Ensure the response is sane.
@@ -204,7 +208,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Note that this uses the delete_devices endpoint so that we can modify
# the payload half-way through some tests.
channel = self.make_request(
- "POST", "delete_devices", body, access_token=self.user_tok,
+ "POST",
+ "delete_devices",
+ body,
+ access_token=self.user_tok,
)
# Ensure the response is sane.
@@ -417,7 +424,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
# and now the delete request should succeed.
self.delete_device(
- self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+ self.user_tok,
+ self.device_id,
+ 200,
+ body={"auth": {"session": session_id}},
)
@skip_unless(HAS_OIDC, "requires OIDC")
@@ -443,8 +453,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self):
- """A user that had a password and then logged in with SSO should get both flows
- """
+ """A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
@@ -459,8 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self):
- """If the user tries to authenticate with the wrong SSO user, they get an error
- """
+ """If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index fba34def30..5ebc5707a5 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -91,7 +91,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_TOO_SHORT,
+ channel.result,
)
def test_password_no_digit(self):
@@ -100,7 +102,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_DIGIT,
+ channel.result,
)
def test_password_no_symbol(self):
@@ -109,7 +113,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_SYMBOL,
+ channel.result,
)
def test_password_no_uppercase(self):
@@ -118,7 +124,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_UPPERCASE,
+ channel.result,
)
def test_password_no_lowercase(self):
@@ -127,7 +135,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(
- channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result,
+ channel.json_body["errcode"],
+ Codes.PASSWORD_NO_LOWERCASE,
+ channel.result,
)
def test_password_compliant(self):
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index bd574077e7..7c457754f1 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -83,14 +83,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_deny_membership(self):
- """Test that we deny relations on membership events
- """
+ """Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
def test_deny_double_react(self):
- """Test that we deny relations on membership events
- """
+ """Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -98,8 +96,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
- """Tests that calling pagination API correctly the latest relations.
- """
+ """Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
@@ -174,8 +171,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation_pagination_groups(self):
- """Test that we can paginate annotation groups correctly.
- """
+ """Test that we can paginate annotation groups correctly."""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
@@ -240,8 +236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(sent_groups, found_groups)
def test_aggregation_pagination_within_group(self):
- """Test that we can paginate within an annotation group.
- """
+ """Test that we can paginate within an annotation group."""
# We need to create ten separate users to send each reaction.
access_tokens = [self.user_token, self.user2_token]
@@ -311,8 +306,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(found_event_ids, expected_event_ids)
def test_aggregation(self):
- """Test that annotations get correctly aggregated.
- """
+ """Test that annotations get correctly aggregated."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -344,8 +338,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_aggregation_redactions(self):
- """Test that annotations get correctly aggregated after a redaction.
- """
+ """Test that annotations get correctly aggregated after a redaction."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@@ -379,8 +372,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_aggregation_must_be_annotation(self):
- """Test that aggregations must be annotations.
- """
+ """Test that aggregations must be annotations."""
channel = self.make_request(
"GET",
@@ -437,8 +429,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
def test_edit(self):
- """Test that a simple edit works.
- """
+ """Test that a simple edit works."""
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 7f68032d9d..899f4902d7 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -481,13 +481,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Check that room name changes increase the unread counter.
self.helper.send_state(
- self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ self.room_id,
+ "m.room.name",
+ {"name": "my super room"},
+ tok=self.tok2,
)
self._check_unread_count(1)
# Check that room topic changes increase the unread counter.
self.helper.send_state(
- self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ self.room_id,
+ "m.room.topic",
+ {"topic": "welcome!!!"},
+ tok=self.tok2,
)
self._check_unread_count(2)
@@ -497,7 +503,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Check that custom events with a body increase the unread counter.
self.helper.send_event(
- self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ self.room_id,
+ "org.matrix.custom_type",
+ {"body": "hello"},
+ tok=self.tok2,
)
self._check_unread_count(4)
@@ -536,14 +545,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.tok,
+ "GET",
+ self.url % self.next_batch,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
room_entry = channel.json_body["rooms"]["join"][self.room_id]
self.assertEqual(
- room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ room_entry["org.matrix.msc2654.unread_count"],
+ expected_count,
+ room_entry,
)
# Store the next batch for the next request.
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py
new file mode 100644
index 0000000000..d890d11863
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_upgrade_room.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+from tests.server import FakeChannel
+
+
+class UpgradeRoomTest(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ self.creator = self.register_user("creator", "pass")
+ self.creator_token = self.login(self.creator, "pass")
+
+ self.other = self.register_user("user", "pass")
+ self.other_token = self.login(self.other, "pass")
+
+ self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token)
+ self.helper.join(self.room_id, self.other, tok=self.other_token)
+
+ def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel:
+ # We never want a cached response.
+ self.reactor.advance(5 * 60 + 1)
+
+ return self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id,
+ # This will upgrade a room to the same version, but that's fine.
+ content={"new_version": DEFAULT_ROOM_VERSION},
+ access_token=token or self.creator_token,
+ )
+
+ def test_upgrade(self):
+ """
+ Upgrading a room should work fine.
+ """
+ channel = self._upgrade_room()
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("replacement_room", channel.json_body)
+
+ def test_not_in_room(self):
+ """
+ Upgrading a room should work fine.
+ """
+ # THe user isn't in the room.
+ roomless = self.register_user("roomless", "pass")
+ roomless_token = self.login(roomless, "pass")
+
+ channel = self._upgrade_room(roomless_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ def test_power_levels(self):
+ """
+ Another user can upgrade the room if their power level is increased.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["users"][self.other] = 100
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def test_power_levels_user_default(self):
+ """
+ Another user can upgrade the room if the default power level for users is increased.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["users_default"] = 100
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ def test_power_levels_tombstone(self):
+ """
+ Another user can upgrade the room if they can send the tombstone event.
+ """
+ # The other user doesn't have the proper power level.
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(403, channel.code, channel.result)
+
+ # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ power_levels["events"]["m.room.tombstone"] = 0
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ body=power_levels,
+ tok=self.creator_token,
+ )
+
+ # The upgrade should succeed!
+ channel = self._upgrade_room(self.other_token)
+ self.assertEquals(200, channel.code, channel.result)
+
+ power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.creator_token,
+ )
+ self.assertNotIn(self.other, power_levels["users"])
|