summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_room.py84
-rw-r--r--tests/rest/client/v1/test_login.py16
-rw-r--r--tests/rest/client/v1/test_typing.py28
-rw-r--r--tests/rest/media/v1/test_media_storage.py94
4 files changed, 190 insertions, 32 deletions
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 7c47aa7e0a..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1445,6 +1445,90 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
+    def test_context_as_non_admin(self):
+        """
+        Test that, without being admin, one cannot use the context admin API
+        """
+        # Create a room.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+
+        self.register_user("test_2", "test")
+        user_tok_2 = self.login("test_2", "test")
+
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        # Populate the room with events.
+        events = []
+        for i in range(30):
+            events.append(
+                self.helper.send_event(
+                    room_id, "com.example.test", content={"index": i}, tok=user_tok
+                )
+            )
+
+        # Now attempt to find the context using the admin API without being admin.
+        midway = (len(events) - 1) // 2
+        for tok in [user_tok, user_tok_2]:
+            channel = self.make_request(
+                "GET",
+                "/_synapse/admin/v1/rooms/%s/context/%s"
+                % (room_id, events[midway]["event_id"]),
+                access_token=tok,
+            )
+            self.assertEquals(
+                403, int(channel.result["code"]), msg=channel.result["body"]
+            )
+            self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_context_as_admin(self):
+        """
+        Test that, as admin, we can find the context of an event without having joined the room.
+        """
+
+        # Create a room. We're not part of it.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        # Populate the room with events.
+        events = []
+        for i in range(30):
+            events.append(
+                self.helper.send_event(
+                    room_id, "com.example.test", content={"index": i}, tok=user_tok
+                )
+            )
+
+        # Now let's fetch the context for this room.
+        midway = (len(events) - 1) // 2
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/rooms/%s/context/%s"
+            % (room_id, events[midway]["event_id"]),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(
+            channel.json_body["event"]["event_id"], events[midway]["event_id"]
+        )
+
+        for i, found_event in enumerate(channel.json_body["events_before"]):
+            for j, posted_event in enumerate(events):
+                if found_event["event_id"] == posted_event["event_id"]:
+                    self.assertTrue(j < midway)
+                    break
+            else:
+                self.fail("Event %s from events_before not found" % j)
+
+        for i, found_event in enumerate(channel.json_body["events_after"]):
+            for j, posted_event in enumerate(events):
+                if found_event["event_id"] == posted_event["event_id"]:
+                    self.assertTrue(j > midway)
+                    break
+            else:
+                self.fail("Event %s from events_after not found" % j)
+
 
 class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index bfcb786af8..49543d9acb 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
 
 import time
 import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 from urllib.parse import urlencode
 
 from mock import Mock
@@ -493,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
+        html = channel.result["body"].decode("utf-8")
         p = TestHtmlParser()
-        p.feed(channel.result["body"].decode("utf-8"))
+        p.feed(html)
         p.close()
 
-        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+        # there should be a link for each href
+        returned_idps = []  # type: List[str]
+        for link in p.links:
+            path, query = link.split("?", 1)
+            self.assertEqual(path, "pick_idp")
+            params = urllib.parse.parse_qs(query)
+            self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+            returned_idps.append(params["idp"][0])
 
-        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+        self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..f6f3b9a356 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.rest.client.v1 import room
 from synapse.types import UserID
 
@@ -60,32 +58,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_datastore().insert_client_ip = _insert_client_ip
 
-        def get_room_members(room_id):
-            if room_id == self.room_id:
-                return defer.succeed([self.user])
-            else:
-                return defer.succeed([])
-
-        @defer.inlineCallbacks
-        def fetch_room_distributions_into(
-            room_id, localusers=None, remotedomains=None, ignore_user=None
-        ):
-            members = yield get_room_members(room_id)
-            for member in members:
-                if ignore_user is not None and member == ignore_user:
-                    continue
-
-                if hs.is_mine(member):
-                    if localusers is not None:
-                        localusers.add(member)
-                else:
-                    if remotedomains is not None:
-                        remotedomains.add(member.domain)
-
-        hs.get_room_member_handler().fetch_room_distributions_into = (
-            fetch_room_distributions_into
-        )
-
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.utils import default_config
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"X-Robots-Tag"),
             [b"noindex, nofollow, noarchive, noimageindex"],
         )
+
+
+class TestSpamChecker:
+    """A spam checker module that rejects all media that includes the bytes
+    `evil`.
+    """
+
+    def __init__(self, config, api):
+        self.config = config
+        self.api = api
+
+    def parse_config(config):
+        return config
+
+    async def check_event_for_spam(self, foo):
+        return False  # allow all events
+
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+        return True  # allow all invites
+
+    async def user_may_create_room(self, userid):
+        return True  # allow all room creations
+
+    async def user_may_create_room_alias(self, userid, room_alias):
+        return True  # allow all room aliases
+
+    async def user_may_publish_room(self, userid, room_id):
+        return True  # allow publishing of all rooms
+
+    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+        buf = BytesIO()
+        await file_wrapper.write_chunks_to(buf.write)
+
+        return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+
+    def default_config(self):
+        config = default_config("test")
+
+        config.update(
+            {
+                "spam_checker": [
+                    {
+                        "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+                        "config": {},
+                    }
+                ]
+            }
+        )
+
+        return config
+
+    def test_upload_innocent(self):
+        """Attempt to upload some innocent data that should be allowed.
+        """
+
+        image_data = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+        self.helper.upload_media(
+            self.upload_resource, image_data, tok=self.tok, expect_code=200
+        )
+
+    def test_upload_ban(self):
+        """Attempt to upload some data that includes bytes "evil", which should
+        get rejected by the spam checker.
+        """
+
+        data = b"Some evil data"
+
+        self.helper.upload_media(
+            self.upload_resource, data, tok=self.tok, expect_code=400
+        )