diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 7c47aa7e0a..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1445,6 +1445,90 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+ def test_context_as_non_admin(self):
+ """
+ Test that, without being admin, one cannot use the context admin API
+ """
+ # Create a room.
+ user_id = self.register_user("test", "test")
+ user_tok = self.login("test", "test")
+
+ self.register_user("test_2", "test")
+ user_tok_2 = self.login("test_2", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ # Populate the room with events.
+ events = []
+ for i in range(30):
+ events.append(
+ self.helper.send_event(
+ room_id, "com.example.test", content={"index": i}, tok=user_tok
+ )
+ )
+
+ # Now attempt to find the context using the admin API without being admin.
+ midway = (len(events) - 1) // 2
+ for tok in [user_tok, user_tok_2]:
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/context/%s"
+ % (room_id, events[midway]["event_id"]),
+ access_token=tok,
+ )
+ self.assertEquals(
+ 403, int(channel.result["code"]), msg=channel.result["body"]
+ )
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_context_as_admin(self):
+ """
+ Test that, as admin, we can find the context of an event without having joined the room.
+ """
+
+ # Create a room. We're not part of it.
+ user_id = self.register_user("test", "test")
+ user_tok = self.login("test", "test")
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ # Populate the room with events.
+ events = []
+ for i in range(30):
+ events.append(
+ self.helper.send_event(
+ room_id, "com.example.test", content={"index": i}, tok=user_tok
+ )
+ )
+
+ # Now let's fetch the context for this room.
+ midway = (len(events) - 1) // 2
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/context/%s"
+ % (room_id, events[midway]["event_id"]),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(
+ channel.json_body["event"]["event_id"], events[midway]["event_id"]
+ )
+
+ for i, found_event in enumerate(channel.json_body["events_before"]):
+ for j, posted_event in enumerate(events):
+ if found_event["event_id"] == posted_event["event_id"]:
+ self.assertTrue(j < midway)
+ break
+ else:
+ self.fail("Event %s from events_before not found" % j)
+
+ for i, found_event in enumerate(channel.json_body["events_after"]):
+ for j, posted_event in enumerate(events):
+ if found_event["event_id"] == posted_event["event_id"]:
+ self.assertTrue(j > midway)
+ break
+ else:
+ self.fail("Event %s from events_after not found" % j)
+
class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index bfcb786af8..49543d9acb 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
import time
import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
from urllib.parse import urlencode
from mock import Mock
@@ -493,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
+ html = channel.result["body"].decode("utf-8")
p = TestHtmlParser()
- p.feed(channel.result["body"].decode("utf-8"))
+ p.feed(html)
p.close()
- self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+ # there should be a link for each href
+ returned_idps = [] # type: List[str]
+ for link in p.links:
+ path, query = link.split("?", 1)
+ self.assertEqual(path, "pick_idp")
+ params = urllib.parse.parse_qs(query)
+ self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+ returned_idps.append(params["idp"][0])
- self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+ self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..f6f3b9a356 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
from mock import Mock
-from twisted.internet import defer
-
from synapse.rest.client.v1 import room
from synapse.types import UserID
@@ -60,32 +58,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_datastore().insert_client_ip = _insert_client_ip
- def get_room_members(room_id):
- if room_id == self.room_id:
- return defer.succeed([self.user])
- else:
- return defer.succeed([])
-
- @defer.inlineCallbacks
- def fetch_room_distributions_into(
- room_id, localusers=None, remotedomains=None, ignore_user=None
- ):
- members = yield get_room_members(room_id)
- for member in members:
- if ignore_user is not None and member == ignore_user:
- continue
-
- if hs.is_mine(member):
- if localusers is not None:
- localusers.add(member)
- else:
- if remotedomains is not None:
- remotedomains.add(member.domain)
-
- hs.get_room_member_handler().fetch_room_distributions_into = (
- fetch_room_distributions_into
- )
-
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"],
)
+
+
+class TestSpamChecker:
+ """A spam checker module that rejects all media that includes the bytes
+ `evil`.
+ """
+
+ def __init__(self, config, api):
+ self.config = config
+ self.api = api
+
+ def parse_config(config):
+ return config
+
+ async def check_event_for_spam(self, foo):
+ return False # allow all events
+
+ async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ return True # allow all invites
+
+ async def user_may_create_room(self, userid):
+ return True # allow all room creations
+
+ async def user_may_create_room_alias(self, userid, room_alias):
+ return True # allow all room aliases
+
+ async def user_may_publish_room(self, userid, room_id):
+ return True # allow publishing of all rooms
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ buf = BytesIO()
+ await file_wrapper.write_chunks_to(buf.write)
+
+ return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+ "config": {},
+ }
+ ]
+ }
+ )
+
+ return config
+
+ def test_upload_innocent(self):
+ """Attempt to upload some innocent data that should be allowed.
+ """
+
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ self.helper.upload_media(
+ self.upload_resource, image_data, tok=self.tok, expect_code=200
+ )
+
+ def test_upload_ban(self):
+ """Attempt to upload some data that includes bytes "evil", which should
+ get rejected by the spam checker.
+ """
+
+ data = b"Some evil data"
+
+ self.helper.upload_media(
+ self.upload_resource, data, tok=self.tok, expect_code=400
+ )
|