diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index f4f5569a89..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.json_body["total"], 3)
+ def test_room_state(self):
+ """Test that room state can be requested correctly"""
+ # Create two test rooms
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("state", channel.json_body)
+ # testing that the state events match is painful and not done here. We assume that
+ # the create_room already does the right thing, so no need to verify that we got
+ # the state events it created.
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index e2bb945453..ceb4ad2366 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
import time
import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
from urllib.parse import urlencode
from mock import Mock
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import create_requester
from tests import unittest
@@ -423,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
return config
def create_resource_dict(self) -> Dict[str, Resource]:
- from synapse.rest.oidc import OIDCResource
-
d = super().create_resource_dict()
- d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
- d["/_synapse/oidc"] = OIDCResource(self.hs)
+ d.update(build_synapse_client_resource_tree(self.hs))
return d
def test_get_login_flows(self):
@@ -497,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
+ html = channel.result["body"].decode("utf-8")
p = TestHtmlParser()
- p.feed(channel.result["body"].decode("utf-8"))
+ p.feed(html)
p.close()
- self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+ # there should be a link for each href
+ returned_idps = [] # type: List[str]
+ for link in p.links:
+ path, query = link.split("?", 1)
+ self.assertEqual(path, "pick_idp")
+ params = urllib.parse.parse_qs(query)
+ self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+ returned_idps.append(params["idp"][0])
- self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+ self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
@@ -1211,11 +1215,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
return config
def create_resource_dict(self) -> Dict[str, Resource]:
- from synapse.rest.oidc import OIDCResource
-
d = super().create_resource_dict()
- d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
- d["/_synapse/oidc"] = OIDCResource(self.hs)
+ d.update(build_synapse_client_resource_tree(self.hs))
return d
def test_username_picker(self):
@@ -1229,7 +1230,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result)
picker_url = channel.headers.getRawHeaders("Location")[0]
- self.assertEqual(picker_url, "/_synapse/client/pick_username")
+ self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
cookies = {} # type: Dict[str,str]
@@ -1253,12 +1254,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
# Now, submit a username to the username picker, which should serve a redirect
- # back to the client
- submit_path = picker_url + "/submit"
+ # to the completion page
content = urlencode({b"username": b"bobby"}).encode("utf8")
chan = self.make_request(
"POST",
- path=submit_path,
+ path=picker_url,
content=content,
content_is_form=True,
custom_headers=[
@@ -1270,6 +1270,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_email(self):
+ """Test that we ratelimit /requestToken for the same email.
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test1@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ def reset(ip):
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret, ip)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ self.email_attempts.clear()
+
+ # We expect to be able to make three requests before getting rate
+ # limited.
+ #
+ # We change IPs to ensure that we're not being ratelimited due to the
+ # same IP
+ reset("127.0.0.1")
+ reset("127.0.0.2")
+ reset("127.0.0.3")
+
+ with self.assertRaises(HttpResponseException) as cm:
+ reset("127.0.0.4")
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
@@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret):
+ def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
+ client_ip=ip,
)
- self.assertEquals(200, channel.code, channel.result)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body["sid"]
@@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_ip(self):
+ """Tests that adding emails is ratelimited by IP
+ """
+
+ # We expect to be able to set three emails before getting ratelimited.
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+ self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+ self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+ with self.assertRaises(HttpResponseException) as cm:
+ self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
@@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
body["next_link"] = next_link
channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
- self.assertEquals(expect_code, channel.code, channel.result)
+
+ if channel.code != expect_code:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body.get("sid")
@@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
+ previous_email_attempts = len(self.email_attempts)
+
client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)
- self.assertEquals(len(self.email_attempts), 1)
+ self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()
self._validate_token(link)
@@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+ threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+ self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index a6488a3d29..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -22,7 +22,7 @@ from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.rest.oidc import OIDCResource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import JsonDict, UserID
from tests import unittest
@@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
def create_resource_dict(self):
resource_dict = super().create_resource_dict()
- if HAS_OIDC:
- # mount the OIDC resource at /_synapse/oidc
- resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"],
)
+
+
+class TestSpamChecker:
+ """A spam checker module that rejects all media that includes the bytes
+ `evil`.
+ """
+
+ def __init__(self, config, api):
+ self.config = config
+ self.api = api
+
+ def parse_config(config):
+ return config
+
+ async def check_event_for_spam(self, foo):
+ return False # allow all events
+
+ async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ return True # allow all invites
+
+ async def user_may_create_room(self, userid):
+ return True # allow all room creations
+
+ async def user_may_create_room_alias(self, userid, room_alias):
+ return True # allow all room aliases
+
+ async def user_may_publish_room(self, userid, room_id):
+ return True # allow publishing of all rooms
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ buf = BytesIO()
+ await file_wrapper.write_chunks_to(buf.write)
+
+ return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+ "config": {},
+ }
+ ]
+ }
+ )
+
+ return config
+
+ def test_upload_innocent(self):
+ """Attempt to upload some innocent data that should be allowed.
+ """
+
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ self.helper.upload_media(
+ self.upload_resource, image_data, tok=self.tok, expect_code=200
+ )
+
+ def test_upload_ban(self):
+ """Attempt to upload some data that includes bytes "evil", which should
+ get rejected by the spam checker.
+ """
+
+ data = b"Some evil data"
+
+ self.helper.upload_media(
+ self.upload_resource, data, tok=self.tok, expect_code=400
+ )
|