diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1e9ba3a201..e2915eb7b1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -269,8 +269,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
- Raises:
- RegistrationError if there was a problem registering.
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 758ee071a5..4cbe9784ed 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
- user_id1 = "@user1:server"
- user_id2 = "@user2:server"
+ user_id1 = "@user1:test"
+ user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b68e9fe082..b1b037006d 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def test_invites(self):
self.persist(type="m.room.create", key="", creator=USER_ID)
- self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
- "get_invited_rooms_for_user",
+ "get_invited_rooms_for_local_user",
[USER_ID_2],
[
RoomsForUser(
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6ceb483aa8..7a7e898843 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -14,11 +14,17 @@
# limitations under the License.
import json
+import os
+import urllib.parse
+from binascii import unhexlify
from mock import Mock
+from twisted.internet.defer import Deferred
+
import synapse.rest.admin
from synapse.http.server import JsonResource
+from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
from synapse.rest.client.v1 import events, login, room
from synapse.rest.client.v2_alpha import groups
@@ -346,3 +352,338 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
test_purge_room.skip = "Disabled because it's currently broken"
+
+
+class QuarantineMediaTestCase(unittest.HomeserverTestCase):
+ """Test /quarantine_media admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.admin.register_servlets_for_media_repo,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+ self.image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ def make_homeserver(self, reactor, clock):
+
+ self.fetches = []
+
+ def get_file(destination, path, output_stream, args=None, max_size=None):
+ """
+ Returns tuple[int,dict,str,int] of file length, response headers,
+ absolute URI, and response code.
+ """
+
+ def write_to(r):
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ d = Deferred()
+ d.addCallback(write_to)
+ self.fetches.append((d, destination, path, args))
+ return make_deferred_yieldable(d)
+
+ client = Mock()
+ client.get_file = get_file
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+
+ config = self.default_config()
+ config["media_store_path"] = self.media_store_path
+ config["thumbnail_requirements"] = {}
+ config["max_image_pixels"] = 2000000
+
+ provider_config = {
+ "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+ config["media_storage_providers"] = [provider_config]
+
+ hs = self.setup_test_homeserver(config=config, http_client=client)
+
+ return hs
+
+ def test_quarantine_media_requires_admin(self):
+ self.register_user("nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("nonadmin", "pass")
+
+ # Attempt quarantine media APIs as non-admin
+ url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ )
+ self.render(request)
+
+ # Expect a forbidden error
+ self.assertEqual(
+ 403,
+ int(channel.result["code"]),
+ msg="Expected forbidden on quarantining media as a non-admin",
+ )
+
+ # And the roomID/userID endpoint
+ url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+ )
+ self.render(request)
+
+ # Expect a forbidden error
+ self.assertEqual(
+ 403,
+ int(channel.result["code"]),
+ msg="Expected forbidden on quarantining media as a non-admin",
+ )
+
+ def test_quarantine_media_by_id(self):
+ self.register_user("id_admin", "pass", admin=True)
+ admin_user_tok = self.login("id_admin", "pass")
+
+ self.register_user("id_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("id_nonadmin", "pass")
+
+ # Upload some media into the room
+ response = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=admin_user_tok
+ )
+
+ # Extract media ID from the response
+ server_name_and_media_id = response["content_uri"][
+ 6:
+ ] # Cut off the 'mxc://' bit
+ server_name, media_id = server_name_and_media_id.split("/")
+
+ # Attempt to access the media
+ request, channel = self.make_request(
+ "GET",
+ server_name_and_media_id,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be successful
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+
+ # Quarantine the media
+ url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
+ urllib.parse.quote(server_name),
+ urllib.parse.quote(media_id),
+ )
+ request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+
+ # Attempt to access the media
+ request, channel = self.make_request(
+ "GET",
+ server_name_and_media_id,
+ shorthand=False,
+ access_token=admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_name_and_media_id
+ ),
+ )
+
+ def test_quarantine_all_media_in_room(self):
+ self.register_user("room_admin", "pass", admin=True)
+ admin_user_tok = self.login("room_admin", "pass")
+
+ non_admin_user = self.register_user("room_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("room_nonadmin", "pass")
+
+ room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok)
+ self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
+
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+
+ # Extract mxcs
+ mxc_1 = response_1["content_uri"]
+ mxc_2 = response_2["content_uri"]
+
+ # Send it into the room
+ self.helper.send_event(
+ room_id,
+ "m.room.message",
+ content={"body": "image-1", "msgtype": "m.image", "url": mxc_1},
+ txn_id="111",
+ tok=non_admin_user_tok,
+ )
+ self.helper.send_event(
+ room_id,
+ "m.room.message",
+ content={"body": "image-2", "msgtype": "m.image", "url": mxc_2},
+ txn_id="222",
+ tok=non_admin_user_tok,
+ )
+
+ # Quarantine all media in the room
+ url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
+ room_id
+ )
+ request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+ self.assertEqual(
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
+ )
+
+ # Convert mxc URLs to server/media_id strings
+ server_and_media_id_1 = mxc_1[6:]
+ server_and_media_id_2 = mxc_2[6:]
+
+ # Test that we cannot download any of the media anymore
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_1,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_1
+ ),
+ )
+
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_2,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_2
+ ),
+ )
+
+ def test_quarantine_all_media_by_user(self):
+ self.register_user("user_admin", "pass", admin=True)
+ admin_user_tok = self.login("user_admin", "pass")
+
+ non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("user_nonadmin", "pass")
+
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+
+ # Extract media IDs
+ server_and_media_id_1 = response_1["content_uri"][6:]
+ server_and_media_id_2 = response_2["content_uri"][6:]
+
+ # Quarantine all media by this user
+ url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+ non_admin_user
+ )
+ request, channel = self.make_request(
+ "POST", url.encode("ascii"), access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.pump(1.0)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 2},
+ "Expected 2 quarantined items",
+ )
+
+ # Attempt to access each piece of media
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_1,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_1,
+ ),
+ )
+
+ # Attempt to access each piece of media
+ request, channel = self.make_request(
+ "GET",
+ server_and_media_id_2,
+ shorthand=False,
+ access_token=non_admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id_2
+ ),
+ )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index e7417b3d14..873d5ef99c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -21,6 +21,8 @@ import time
import attr
+from twisted.web.resource import Resource
+
from synapse.api.constants import Membership
from tests.server import make_request, render
@@ -160,3 +162,38 @@ class RestHelper(object):
)
return channel.json_body
+
+ def upload_media(
+ self,
+ resource: Resource,
+ image_data: bytes,
+ tok: str,
+ filename: str = "test.png",
+ expect_code: int = 200,
+ ) -> dict:
+ """Upload a piece of test media to the media repo
+ Args:
+ resource: The resource that will handle the upload request
+ image_data: The image data to upload
+ tok: The user token to use during the upload
+ filename: The filename of the media to be uploaded
+ expect_code: The return code to expect from attempting to upload the media
+ """
+ image_length = len(image_data)
+ path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
+ request, channel = make_request(
+ self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
+ )
+ request.requestHeaders.addRawHeader(
+ b"Content-Length", str(image_length).encode("UTF-8")
+ )
+ request.render(resource)
+ self.hs.get_reactor().pump([100])
+
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ int(channel.result["code"]),
+ channel.result["body"],
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0f51895b81..c3facc00eb 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
# Make sure the invite is here.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
@@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
- store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+ store.get_rooms_for_local_user_where_membership_is(
+ invitee_id, [Membership.LEAVE]
+ )
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 661c1f88b9..9c13a13786 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,8 +15,6 @@
# limitations under the License.
import json
-from mock import Mock
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
@@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
- )
- return hs
-
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
self.render(request)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 7840f63fe3..00df0ea68e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
rooms_for_user = self.get_success(
- self.store.get_rooms_for_user_where_membership_is(
+ self.store.get_rooms_for_local_user_where_membership_is(
self.u_alice, [Membership.JOIN]
)
)
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ JsonResource,
+ wrap_html_request_handler,
+)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+ class TestResource(DirectServeResource):
+ callback = None
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self.callback(request)
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def test_good_response(self):
+ def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ body = channel.result["body"]
+ self.assertEqual(body, b"response")
+
+ def test_redirect_exception(self):
+ """
+ If the callback raises a RedirectException, it is turned into a 30x
+ with the right location.
+ """
+
+ def callback(request, **kwargs):
+ raise RedirectException(b"/look/an/eagle", 301)
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"301")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+ def test_redirect_exception_with_cookie(self):
+ """
+ If the callback raises a RedirectException which sets a cookie, that is
+ returned too
+ """
+
+ def callback(request, **kwargs):
+ e = RedirectException(b"/no/over/there", 304)
+ e.cookies.append(b"session=yespls")
+ raise e
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"304")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/no/over/there"])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""
|