diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 0a5ca317ea..2ae896db1e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -150,6 +150,8 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
event_id = resp["event_id"]
channel = self.make_request(
- "GET", "/events/" + event_id, access_token=self.token,
+ "GET",
+ "/events/" + event_id,
+ access_token=self.token,
)
self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2672ce24c6..fb29eaed6f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
import time
import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
from urllib.parse import urlencode
from mock import Mock
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import create_requester
from tests import unittest
@@ -75,6 +74,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -419,13 +422,61 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
return config
def create_resource_dict(self) -> Dict[str, Resource]:
- from synapse.rest.oidc import OIDCResource
-
d = super().create_resource_dict()
- d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
- d["/_synapse/oidc"] = OIDCResource(self.hs)
+ d.update(build_synapse_client_resource_tree(self.hs))
return d
+ def test_get_login_flows(self):
+ """GET /login should return password and SSO flows"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.sso"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_get_msc2858_login_flows(self):
+ """The SSO flow should include IdP info if MSC2858 is enabled"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # stick the flows results in a dict by type
+ flow_results = {} # type: Dict[str, Any]
+ for f in channel.json_body["flows"]:
+ flow_type = f["type"]
+ self.assertNotIn(
+ flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+ )
+ flow_results[flow_type] = f
+
+ self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+ sso_flow = flow_results.pop("m.login.sso")
+ # we should have a set of IdPs
+ self.assertCountEqual(
+ sso_flow["org.matrix.msc2858.identity_providers"],
+ [
+ {"id": "cas", "name": "CAS"},
+ {"id": "saml", "name": "SAML"},
+ {"id": "oidc-idp1", "name": "IDP1"},
+ {"id": "oidc", "name": "OIDC"},
+ ],
+ )
+
+ # the rest of the flows are simple
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(flow_results.values(), expected_flows)
+
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
@@ -442,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
+ html = channel.result["body"].decode("utf-8")
p = TestHtmlParser()
- p.feed(channel.result["body"].decode("utf-8"))
+ p.feed(html)
p.close()
- self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+ # there should be a link for each href
+ returned_idps = [] # type: List[str]
+ for link in p.links:
+ path, query = link.split("?", 1)
+ self.assertEqual(path, "pick_idp")
+ params = urllib.parse.parse_qs(query)
+ self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+ returned_idps.append(params["idp"][0])
- self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+ self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
@@ -552,7 +611,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
- "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
@@ -560,9 +621,47 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
- "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_unknown(self):
+ """If the client tries to pick an unknown IdP, return a 404"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 404, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_oidc(self):
+ """If the client pick a known IdP, redirect to it"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@@ -584,10 +683,12 @@ class CASTestCase(unittest.HomeserverTestCase):
self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
config = self.default_config()
+ config["public_baseurl"] = (
+ config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+ )
config["cas_config"] = {
"enabled": True,
"server_url": CAS_SERVER,
- "service_url": "https://matrix.goodserver.com:8448",
}
cas_user_id = "username"
@@ -621,7 +722,8 @@ class CASTestCase(unittest.HomeserverTestCase):
mocked_http_client.get_raw.side_effect = get_raw
self.hs = self.setup_test_homeserver(
- config=config, proxied_http_client=mocked_http_client,
+ config=config,
+ proxied_http_client=mocked_http_client,
)
return self.hs
@@ -1119,11 +1221,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
return config
def create_resource_dict(self) -> Dict[str, Resource]:
- from synapse.rest.oidc import OIDCResource
-
d = super().create_resource_dict()
- d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
- d["/_synapse/oidc"] = OIDCResource(self.hs)
+ d.update(build_synapse_client_resource_tree(self.hs))
return d
def test_username_picker(self):
@@ -1137,7 +1236,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result)
picker_url = channel.headers.getRawHeaders("Location")[0]
- self.assertEqual(picker_url, "/_synapse/client/pick_username")
+ self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
cookies = {} # type: Dict[str,str]
@@ -1149,7 +1248,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# looks ok.
username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
self.assertIn(
- session_id, username_mapping_sessions, "session id not found in map",
+ session_id,
+ username_mapping_sessions,
+ "session id not found in map",
)
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
@@ -1161,12 +1262,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
# Now, submit a username to the username picker, which should serve a redirect
- # back to the client
- submit_path = picker_url + "/submit"
+ # to the completion page
content = urlencode({b"username": b"bobby"}).encode("utf8")
chan = self.make_request(
"POST",
- path=submit_path,
+ path=picker_url,
content=content,
content_is_form=True,
custom_headers=[
@@ -1178,6 +1278,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")
@@ -1195,7 +1305,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
chan = self.make_request(
- "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index e59fa70baa..f3448c94dd 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,163 +14,11 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
-import json
-
-from mock import Mock
-
-from twisted.internet import defer
-
-import synapse.types
-from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.rest.client.v1 import login, profile, room
from tests import unittest
-from ....utils import MockHttpResource, setup_test_homeserver
-
-myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/r0"
-
-
-class MockHandlerProfileTestCase(unittest.TestCase):
- """ Tests rest layer of profile management.
-
- Todo: move these into ProfileTestCase
- """
-
- @defer.inlineCallbacks
- def setUp(self):
- self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
- self.mock_handler = Mock(
- spec=[
- "get_displayname",
- "set_displayname",
- "get_avatar_url",
- "set_avatar_url",
- "check_profile_query_allowed",
- ]
- )
-
- self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
- self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
- self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
- Mock()
- )
-
- hs = yield setup_test_homeserver(
- self.addCleanup,
- "test",
- federation_http_client=None,
- resource_for_client=self.mock_resource,
- federation=Mock(),
- federation_client=Mock(),
- profile_handler=self.mock_handler,
- )
-
- async def _get_user_by_req(request=None, allow_guest=False):
- return synapse.types.create_requester(myid)
-
- hs.get_auth().get_user_by_req = _get_user_by_req
-
- profile.register_servlets(hs, self.mock_resource)
-
- @defer.inlineCallbacks
- def test_get_my_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Frank")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Frank"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
-
- @defer.inlineCallbacks
- def test_set_my_name_noauth(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = AuthError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@4567:test"),
- b'{"displayname": "Frank Jr."}',
- )
-
- self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_other_name(self):
- mocked_get = self.mock_handler.get_displayname
- mocked_get.return_value = defer.succeed("Bob")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"displayname": "Bob"}, response)
-
- @defer.inlineCallbacks
- def test_set_other_name(self):
- mocked_set = self.mock_handler.set_displayname
- mocked_set.side_effect = SynapseError(400, "message")
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/displayname" % ("@opaque:elsewhere"),
- b'{"displayname":"bob"}',
- )
-
- self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
-
- @defer.inlineCallbacks
- def test_get_my_avatar(self):
- mocked_get = self.mock_handler.get_avatar_url
- mocked_get.return_value = defer.succeed("http://my.server/me.png")
-
- (code, response) = yield self.mock_resource.trigger(
- "GET", "/profile/%s/avatar_url" % (myid), None
- )
-
- self.assertEquals(200, code)
- self.assertEquals({"avatar_url": "http://my.server/me.png"}, response)
- self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD")
-
- @defer.inlineCallbacks
- def test_set_my_avatar(self):
- mocked_set = self.mock_handler.set_avatar_url
- mocked_set.return_value = defer.succeed(())
-
- (code, response) = yield self.mock_resource.trigger(
- "PUT",
- "/profile/%s/avatar_url" % (myid),
- b'{"avatar_url": "http://my.server/pic.gif"}',
- )
-
- self.assertEquals(200, code)
- self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
- self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
-
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.owner = self.register_user("owner", "pass")
self.owner_tok = self.login("owner", "pass")
+ self.other = self.register_user("other", "pass", displayname="Bob")
+
+ def test_get_displayname(self):
+ res = self._get_displayname()
+ self.assertEqual(res, "owner")
def test_set_displayname(self):
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test"}),
+ content={"displayname": "test"},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 200, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "test")
+ def test_set_displayname_noauth(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner,),
+ content={"displayname": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
def test_set_displayname_too_long(self):
"""Attempts to set a stupid displayname should get a 400"""
channel = self.make_request(
"PUT",
"/profile/%s/displayname" % (self.owner,),
- content=json.dumps({"displayname": "test" * 100}),
+ content={"displayname": "test" * 100},
access_token=self.owner_tok,
)
self.assertEqual(channel.code, 400, channel.result)
- res = self.get_displayname()
+ res = self._get_displayname()
self.assertEqual(res, "owner")
- def get_displayname(self):
- channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
+ def test_get_displayname_other(self):
+ res = self._get_displayname(self.other)
+ self.assertEquals(res, "Bob")
+
+ def test_set_displayname_other(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.other,),
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def test_get_avatar_url(self):
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_set_avatar_url(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertEqual(res, "http://my.server/pic.gif")
+
+ def test_set_avatar_url_noauth(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+
+ def test_set_avatar_url_too_long(self):
+ """Attempts to set a stupid avatar_url should get a 400"""
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.owner,),
+ content={"avatar_url": "http://my.server/pic.gif" * 100},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ res = self._get_avatar_url()
+ self.assertIsNone(res)
+
+ def test_get_avatar_url_other(self):
+ res = self._get_avatar_url(self.other)
+ self.assertIsNone(res)
+
+ def test_set_avatar_url_other(self):
+ channel = self.make_request(
+ "PUT",
+ "/profile/%s/avatar_url" % (self.other,),
+ content={"avatar_url": "http://my.server/pic.gif"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_displayname(self, name=None):
+ channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (name or self.owner,)
+ )
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["displayname"]
+ def _get_avatar_url(self, name=None):
+ channel = self.make_request(
+ "GET", "/profile/%s/avatar_url" % (name or self.owner,)
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body.get("avatar_url")
+
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..ed65f645fc 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
@@ -616,6 +618,41 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -1445,7 +1482,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 2, [result["result"]["content"] for result in results],
+ len(results),
+ 2,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1480,7 +1519,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 4, [result["result"]["content"] for result in results],
+ len(results),
+ 4,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
@@ -1527,7 +1568,9 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results = channel.json_body["search_categories"]["room_events"]["results"]
self.assertEqual(
- len(results), 1, [result["result"]["content"] for result in results],
+ len(results),
+ 1,
+ [result["result"]["content"] for result in results],
)
self.assertEqual(
results[0]["result"]["content"]["body"],
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..329dbd06de 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.rest.client.v1 import room
from synapse.types import UserID
@@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "red", federation_http_client=None, federation_client=Mock(),
+ "red",
+ federation_http_client=None,
+ federation_client=Mock(),
)
self.event_source = hs.get_event_sources().sources["typing"]
@@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_datastore().insert_client_ip = _insert_client_ip
- def get_room_members(room_id):
- if room_id == self.room_id:
- return defer.succeed([self.user])
- else:
- return defer.succeed([])
-
- @defer.inlineCallbacks
- def fetch_room_distributions_into(
- room_id, localusers=None, remotedomains=None, ignore_user=None
- ):
- members = yield get_room_members(room_id)
- for member in members:
- if ignore_user is not None and member == ignore_user:
- continue
-
- if hs.is_mine(member):
- if localusers is not None:
- localusers.add(member)
- else:
- if remotedomains is not None:
- remotedomains.add(member.domain)
-
- hs.get_room_member_handler().fetch_room_distributions_into = (
- fetch_room_distributions_into
- )
-
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index b1333df82d..8231a423f3 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -166,9 +166,12 @@ class RestHelper:
json.dumps(data).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
self.auth_user_id = temp_id
@@ -201,9 +204,12 @@ class RestHelper:
json.dumps(content).encode("utf8"),
)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -251,9 +257,12 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
+ assert (
+ int(channel.result["code"]) == expect_code
+ ), "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
)
return channel.json_body
@@ -447,7 +456,10 @@ class RestHelper:
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
def complete_oidc_auth(
- self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+ self,
+ oauth_uri: str,
+ cookies: Mapping[str, str],
+ user_info_dict: JsonDict,
) -> FakeChannel:
"""Mock out an OIDC authentication flow
@@ -491,7 +503,9 @@ class RestHelper:
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ code=200,
+ phrase=b"OK",
+ body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
|