summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_register.py2
-rw-r--r--tests/handlers/test_sync.py4
-rw-r--r--tests/replication/slave/storage/test_events.py4
-rw-r--r--tests/rest/admin/test_admin.py341
-rw-r--r--tests/rest/client/v1/utils.py37
-rw-r--r--tests/rest/client/v2_alpha/test_account.py12
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py9
-rw-r--r--tests/storage/test_roommember.py2
-rw-r--r--tests/test_server.py79
9 files changed, 469 insertions, 21 deletions
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1e9ba3a201..e2915eb7b1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -269,8 +269,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
               one will be randomly generated.
         Returns:
             A tuple of (user_id, access_token).
-        Raises:
-            RegistrationError if there was a problem registering.
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 758ee071a5..4cbe9784ed 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
     def test_wait_for_sync_for_user_auth_blocking(self):
 
-        user_id1 = "@user1:server"
-        user_id2 = "@user2:server"
+        user_id1 = "@user1:test"
+        user_id2 = "@user2:test"
         sync_config = self._generate_sync_config(user_id1)
 
         self.reactor.advance(100)  # So we get not 0 time
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b68e9fe082..b1b037006d 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     def test_invites(self):
         self.persist(type="m.room.create", key="", creator=USER_ID)
-        self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+        self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
         event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
 
         self.replicate()
 
         self.check(
-            "get_invited_rooms_for_user",
+            "get_invited_rooms_for_local_user",
             [USER_ID_2],
             [
                 RoomsForUser(
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6ceb483aa8..7a7e898843 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -14,11 +14,17 @@
 # limitations under the License.
 
 import json
+import os
+import urllib.parse
+from binascii import unhexlify
 
 from mock import Mock
 
+from twisted.internet.defer import Deferred
+
 import synapse.rest.admin
 from synapse.http.server import JsonResource
+from synapse.logging.context import make_deferred_yieldable
 from synapse.rest.admin import VersionServlet
 from synapse.rest.client.v1 import events, login, room
 from synapse.rest.client.v2_alpha import groups
@@ -346,3 +352,338 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
             self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
 
     test_purge_room.skip = "Disabled because it's currently broken"
+
+
+class QuarantineMediaTestCase(unittest.HomeserverTestCase):
+    """Test /quarantine_media admin API.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.admin.register_servlets_for_media_repo,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.hs = hs
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+        self.image_data = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+    def make_homeserver(self, reactor, clock):
+
+        self.fetches = []
+
+        def get_file(destination, path, output_stream, args=None, max_size=None):
+            """
+            Returns tuple[int,dict,str,int] of file length, response headers,
+            absolute URI, and response code.
+            """
+
+            def write_to(r):
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            d = Deferred()
+            d.addCallback(write_to)
+            self.fetches.append((d, destination, path, args))
+            return make_deferred_yieldable(d)
+
+        client = Mock()
+        client.get_file = get_file
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+
+        config = self.default_config()
+        config["media_store_path"] = self.media_store_path
+        config["thumbnail_requirements"] = {}
+        config["max_image_pixels"] = 2000000
+
+        provider_config = {
+            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+        config["media_storage_providers"] = [provider_config]
+
+        hs = self.setup_test_homeserver(config=config, http_client=client)
+
+        return hs
+
+    def test_quarantine_media_requires_admin(self):
+        self.register_user("nonadmin", "pass", admin=False)
+        non_admin_user_tok = self.login("nonadmin", "pass")
+
+        # Attempt quarantine media APIs as non-admin
+        url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
+        request, channel = self.make_request(
+            "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+        )
+        self.render(request)
+
+        # Expect a forbidden error
+        self.assertEqual(
+            403,
+            int(channel.result["code"]),
+            msg="Expected forbidden on quarantining media as a non-admin",
+        )
+
+        # And the roomID/userID endpoint
+        url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
+        request, channel = self.make_request(
+            "POST", url.encode("ascii"), access_token=non_admin_user_tok,
+        )
+        self.render(request)
+
+        # Expect a forbidden error
+        self.assertEqual(
+            403,
+            int(channel.result["code"]),
+            msg="Expected forbidden on quarantining media as a non-admin",
+        )
+
+    def test_quarantine_media_by_id(self):
+        self.register_user("id_admin", "pass", admin=True)
+        admin_user_tok = self.login("id_admin", "pass")
+
+        self.register_user("id_nonadmin", "pass", admin=False)
+        non_admin_user_tok = self.login("id_nonadmin", "pass")
+
+        # Upload some media into the room
+        response = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=admin_user_tok
+        )
+
+        # Extract media ID from the response
+        server_name_and_media_id = response["content_uri"][
+            6:
+        ]  # Cut off the 'mxc://' bit
+        server_name, media_id = server_name_and_media_id.split("/")
+
+        # Attempt to access the media
+        request, channel = self.make_request(
+            "GET",
+            server_name_and_media_id,
+            shorthand=False,
+            access_token=non_admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be successful
+        self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+
+        # Quarantine the media
+        url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
+            urllib.parse.quote(server_name),
+            urllib.parse.quote(media_id),
+        )
+        request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        self.render(request)
+        self.pump(1.0)
+        self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+
+        # Attempt to access the media
+        request, channel = self.make_request(
+            "GET",
+            server_name_and_media_id,
+            shorthand=False,
+            access_token=admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be quarantined
+        self.assertEqual(
+            404,
+            int(channel.code),
+            msg=(
+                "Expected to receive a 404 on accessing quarantined media: %s"
+                % server_name_and_media_id
+            ),
+        )
+
+    def test_quarantine_all_media_in_room(self):
+        self.register_user("room_admin", "pass", admin=True)
+        admin_user_tok = self.login("room_admin", "pass")
+
+        non_admin_user = self.register_user("room_nonadmin", "pass", admin=False)
+        non_admin_user_tok = self.login("room_nonadmin", "pass")
+
+        room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok)
+        self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
+
+        # Upload some media
+        response_1 = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=non_admin_user_tok
+        )
+        response_2 = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=non_admin_user_tok
+        )
+
+        # Extract mxcs
+        mxc_1 = response_1["content_uri"]
+        mxc_2 = response_2["content_uri"]
+
+        # Send it into the room
+        self.helper.send_event(
+            room_id,
+            "m.room.message",
+            content={"body": "image-1", "msgtype": "m.image", "url": mxc_1},
+            txn_id="111",
+            tok=non_admin_user_tok,
+        )
+        self.helper.send_event(
+            room_id,
+            "m.room.message",
+            content={"body": "image-2", "msgtype": "m.image", "url": mxc_2},
+            txn_id="222",
+            tok=non_admin_user_tok,
+        )
+
+        # Quarantine all media in the room
+        url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
+            room_id
+        )
+        request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        self.render(request)
+        self.pump(1.0)
+        self.assertEqual(200, int(channel.code), msg=channel.result["body"])
+        self.assertEqual(
+            json.loads(channel.result["body"].decode("utf-8")),
+            {"num_quarantined": 2},
+            "Expected 2 quarantined items",
+        )
+
+        # Convert mxc URLs to server/media_id strings
+        server_and_media_id_1 = mxc_1[6:]
+        server_and_media_id_2 = mxc_2[6:]
+
+        # Test that we cannot download any of the media anymore
+        request, channel = self.make_request(
+            "GET",
+            server_and_media_id_1,
+            shorthand=False,
+            access_token=non_admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be quarantined
+        self.assertEqual(
+            404,
+            int(channel.code),
+            msg=(
+                "Expected to receive a 404 on accessing quarantined media: %s"
+                % server_and_media_id_1
+            ),
+        )
+
+        request, channel = self.make_request(
+            "GET",
+            server_and_media_id_2,
+            shorthand=False,
+            access_token=non_admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be quarantined
+        self.assertEqual(
+            404,
+            int(channel.code),
+            msg=(
+                "Expected to receive a 404 on accessing quarantined media: %s"
+                % server_and_media_id_2
+            ),
+        )
+
+    def test_quarantine_all_media_by_user(self):
+        self.register_user("user_admin", "pass", admin=True)
+        admin_user_tok = self.login("user_admin", "pass")
+
+        non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+        non_admin_user_tok = self.login("user_nonadmin", "pass")
+
+        # Upload some media
+        response_1 = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=non_admin_user_tok
+        )
+        response_2 = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=non_admin_user_tok
+        )
+
+        # Extract media IDs
+        server_and_media_id_1 = response_1["content_uri"][6:]
+        server_and_media_id_2 = response_2["content_uri"][6:]
+
+        # Quarantine all media by this user
+        url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+            non_admin_user
+        )
+        request, channel = self.make_request(
+            "POST", url.encode("ascii"), access_token=admin_user_tok,
+        )
+        self.render(request)
+        self.pump(1.0)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            json.loads(channel.result["body"].decode("utf-8")),
+            {"num_quarantined": 2},
+            "Expected 2 quarantined items",
+        )
+
+        # Attempt to access each piece of media
+        request, channel = self.make_request(
+            "GET",
+            server_and_media_id_1,
+            shorthand=False,
+            access_token=non_admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be quarantined
+        self.assertEqual(
+            404,
+            int(channel.code),
+            msg=(
+                "Expected to receive a 404 on accessing quarantined media: %s"
+                % server_and_media_id_1,
+            ),
+        )
+
+        # Attempt to access each piece of media
+        request, channel = self.make_request(
+            "GET",
+            server_and_media_id_2,
+            shorthand=False,
+            access_token=non_admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be quarantined
+        self.assertEqual(
+            404,
+            int(channel.code),
+            msg=(
+                "Expected to receive a 404 on accessing quarantined media: %s"
+                % server_and_media_id_2
+            ),
+        )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index e7417b3d14..873d5ef99c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -21,6 +21,8 @@ import time
 
 import attr
 
+from twisted.web.resource import Resource
+
 from synapse.api.constants import Membership
 
 from tests.server import make_request, render
@@ -160,3 +162,38 @@ class RestHelper(object):
         )
 
         return channel.json_body
+
+    def upload_media(
+        self,
+        resource: Resource,
+        image_data: bytes,
+        tok: str,
+        filename: str = "test.png",
+        expect_code: int = 200,
+    ) -> dict:
+        """Upload a piece of test media to the media repo
+        Args:
+            resource: The resource that will handle the upload request
+            image_data: The image data to upload
+            tok: The user token to use during the upload
+            filename: The filename of the media to be uploaded
+            expect_code: The return code to expect from attempting to upload the media
+        """
+        image_length = len(image_data)
+        path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
+        request, channel = make_request(
+            self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
+        )
+        request.requestHeaders.addRawHeader(
+            b"Content-Length", str(image_length).encode("UTF-8")
+        )
+        request.render(resource)
+        self.hs.get_reactor().pump([100])
+
+        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
+        )
+
+        return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0f51895b81..c3facc00eb 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         )
 
         # Make sure the invite is here.
-        pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+        pending_invites = self.get_success(
+            store.get_invited_rooms_for_local_user(invitee_id)
+        )
         self.assertEqual(len(pending_invites), 1, pending_invites)
         self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
 
@@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         self.deactivate(invitee_id, invitee_tok)
 
         # Check that the invite isn't there anymore.
-        pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+        pending_invites = self.get_success(
+            store.get_invited_rooms_for_local_user(invitee_id)
+        )
         self.assertEqual(len(pending_invites), 0, pending_invites)
 
         # Check that the membership of @invitee:test in the room is now "leave".
         memberships = self.get_success(
-            store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+            store.get_rooms_for_local_user_where_membership_is(
+                invitee_id, [Membership.LEAVE]
+            )
         )
         self.assertEqual(len(memberships), 1, memberships)
         self.assertEqual(memberships[0].room_id, room_id, memberships)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 661c1f88b9..9c13a13786 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 import json
 
-from mock import Mock
-
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes
 from synapse.rest.client.v1 import login, room
@@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-
-        hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock()
-        )
-        return hs
-
     def test_sync_argless(self):
         request, channel = self.make_request("GET", "/sync")
         self.render(request)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 7840f63fe3..00df0ea68e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 
         rooms_for_user = self.get_success(
-            self.store.get_rooms_for_user_where_membership_is(
+            self.store.get_rooms_for_local_user_where_membership_is(
                 self.u_alice, [Membership.JOIN]
             )
         )
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+    DirectServeResource,
+    JsonResource,
+    wrap_html_request_handler,
+)
 from synapse.http.site import SynapseSite, logger
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
 
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+    class TestResource(DirectServeResource):
+        callback = None
+
+        @wrap_html_request_handler
+        async def _async_render_GET(self, request):
+            return await self.callback(request)
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+    def test_good_response(self):
+        def callback(request):
+            request.write(b"response")
+            request.finish()
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        body = channel.result["body"]
+        self.assertEqual(body, b"response")
+
+    def test_redirect_exception(self):
+        """
+        If the callback raises a RedirectException, it is turned into a 30x
+        with the right location.
+        """
+
+        def callback(request, **kwargs):
+            raise RedirectException(b"/look/an/eagle", 301)
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"301")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+    def test_redirect_exception_with_cookie(self):
+        """
+        If the callback raises a RedirectException which sets a cookie, that is
+        returned too
+        """
+
+        def callback(request, **kwargs):
+            e = RedirectException(b"/no/over/there", 304)
+            e.cookies.append(b"session=yespls")
+            raise e
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"304")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/no/over/there"])
+        cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+        self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
 class SiteTestCase(unittest.HomeserverTestCase):
     def test_lose_connection(self):
         """