summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_room.py15
-rw-r--r--tests/rest/client/v1/test_login.py46
-rw-r--r--tests/rest/client/v1/test_rooms.py35
-rw-r--r--tests/rest/client/v2_alpha/test_account.py90
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py6
-rw-r--r--tests/rest/media/v1/test_media_storage.py94
6 files changed, 258 insertions, 28 deletions
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index f4f5569a89..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.json_body["total"], 3)
 
+    def test_room_state(self):
+        """Test that room state can be requested correctly"""
+        # Create two test rooms
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("state", channel.json_body)
+        # testing that the state events match is painful and not done here. We assume that
+        # the create_room already does the right thing, so no need to verify that we got
+        # the state events it created.
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index e2bb945453..ceb4ad2366 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
 
 import time
 import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 from urllib.parse import urlencode
 
 from mock import Mock
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import create_requester
 
 from tests import unittest
@@ -423,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
     def test_get_login_flows(self):
@@ -497,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
+        html = channel.result["body"].decode("utf-8")
         p = TestHtmlParser()
-        p.feed(channel.result["body"].decode("utf-8"))
+        p.feed(html)
         p.close()
 
-        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+        # there should be a link for each href
+        returned_idps = []  # type: List[str]
+        for link in p.links:
+            path, query = link.split("?", 1)
+            self.assertEqual(path, "pick_idp")
+            params = urllib.parse.parse_qs(query)
+            self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+            returned_idps.append(params["idp"][0])
 
-        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+        self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
@@ -1211,11 +1215,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
     def test_username_picker(self):
@@ -1229,7 +1230,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
         picker_url = channel.headers.getRawHeaders("Location")[0]
-        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+        self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
         cookies = {}  # type: Dict[str,str]
@@ -1253,12 +1254,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
         self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
 
         # Now, submit a username to the username picker, which should serve a redirect
-        # back to the client
-        submit_path = picker_url + "/submit"
+        # to the completion page
         content = urlencode({b"username": b"bobby"}).encode("utf8")
         chan = self.make_request(
             "POST",
-            path=submit_path,
+            path=picker_url,
             content=content,
             content_is_form=True,
             custom_headers=[
@@ -1270,6 +1270,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)
         self.assertEqual(path, "https://x")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase):
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomInviteRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        admin.register_servlets,
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_rooms_ratelimit(self):
+        """Tests that invites in a room are actually rate-limited."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        for i in range(3):
+            self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+        self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_users_ratelimit(self):
+        """Tests that invites to a specific user are actually rate-limited."""
+
+        for i in range(3):
+            room_id = self.helper.create_room_as(self.user_id)
+            self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+        room_id = self.helper.create_room_as(self.user_id)
+        self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"
 
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_email(self):
+        """Test that we ratelimit /requestToken for the same email.
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test1@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        def reset(ip):
+            client_secret = "foobar"
+            session_id = self._request_token(email, client_secret, ip)
+
+            self.assertEquals(len(self.email_attempts), 1)
+            link = self._get_link_from_email()
+
+            self._validate_token(link)
+
+            self._reset_password(new_password, session_id, client_secret)
+
+            self.email_attempts.clear()
+
+        # We expect to be able to make three requests before getting rate
+        # limited.
+        #
+        # We change IPs to ensure that we're not being ratelimited due to the
+        # same IP
+        reset("127.0.0.1")
+        reset("127.0.0.2")
+        reset("127.0.0.3")
+
+        with self.assertRaises(HttpResponseException) as cm:
+            reset("127.0.0.4")
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_basic_password_reset_canonicalise_email(self):
         """Test basic password reset flow
         Request password reset with different spelling
@@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
-    def _request_token(self, email, client_secret):
+    def _request_token(self, email, client_secret, ip="127.0.0.1"):
         channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            client_ip=ip,
         )
-        self.assertEquals(200, channel.code, channel.result)
+
+        if channel.code != 200:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body["sid"]
 
@@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_address_trim(self):
         self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_ip(self):
+        """Tests that adding emails is ratelimited by IP
+        """
+
+        # We expect to be able to set three emails before getting ratelimited.
+        self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+        self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+        self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+        with self.assertRaises(HttpResponseException) as cm:
+            self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_add_email_if_disabled(self):
         """Test adding email to profile when doing so is disallowed
         """
@@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             body["next_link"] = next_link
 
         channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
-        self.assertEquals(expect_code, channel.code, channel.result)
+
+        if channel.code != expect_code:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body.get("sid")
 
@@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def _add_email(self, request_email, expected_email):
         """Test adding an email to profile
         """
+        previous_email_attempts = len(self.email_attempts)
+
         client_secret = "foobar"
         session_id = self._request_token(request_email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+        threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+        self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index a6488a3d29..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -22,7 +22,7 @@ from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.rest.oidc import OIDCResource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
@@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def create_resource_dict(self):
         resource_dict = super().create_resource_dict()
-        if HAS_OIDC:
-            # mount the OIDC resource at /_synapse/oidc
-            resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
         return resource_dict
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.utils import default_config
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"X-Robots-Tag"),
             [b"noindex, nofollow, noarchive, noimageindex"],
         )
+
+
+class TestSpamChecker:
+    """A spam checker module that rejects all media that includes the bytes
+    `evil`.
+    """
+
+    def __init__(self, config, api):
+        self.config = config
+        self.api = api
+
+    def parse_config(config):
+        return config
+
+    async def check_event_for_spam(self, foo):
+        return False  # allow all events
+
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+        return True  # allow all invites
+
+    async def user_may_create_room(self, userid):
+        return True  # allow all room creations
+
+    async def user_may_create_room_alias(self, userid, room_alias):
+        return True  # allow all room aliases
+
+    async def user_may_publish_room(self, userid, room_id):
+        return True  # allow publishing of all rooms
+
+    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+        buf = BytesIO()
+        await file_wrapper.write_chunks_to(buf.write)
+
+        return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+
+    def default_config(self):
+        config = default_config("test")
+
+        config.update(
+            {
+                "spam_checker": [
+                    {
+                        "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+                        "config": {},
+                    }
+                ]
+            }
+        )
+
+        return config
+
+    def test_upload_innocent(self):
+        """Attempt to upload some innocent data that should be allowed.
+        """
+
+        image_data = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+        self.helper.upload_media(
+            self.upload_resource, image_data, tok=self.tok, expect_code=200
+        )
+
+    def test_upload_ban(self):
+        """Attempt to upload some data that includes bytes "evil", which should
+        get rejected by the spam checker.
+        """
+
+        data = b"Some evil data"
+
+        self.helper.upload_media(
+            self.upload_resource, data, tok=self.tok, expect_code=400
+        )