summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py104
-rw-r--r--tests/rest/media/v1/test_media_storage.py145
-rw-r--r--tests/rest/media/v1/test_url_preview.py77
3 files changed, 326 insertions, 0 deletions
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
new file mode 100644
index 0000000000..7fa120a10f
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import admin
+from synapse.rest.client.v2_alpha import auth, register
+
+from tests import unittest
+
+
+class FallbackAuthTests(unittest.HomeserverTestCase):
+
+    servlets = [
+        auth.register_servlets,
+        admin.register_servlets,
+        register.register_servlets,
+    ]
+    hijack_auth = False
+
+    def make_homeserver(self, reactor, clock):
+
+        config = self.default_config()
+
+        config.enable_registration_captcha = True
+        config.recaptcha_public_key = "brokencake"
+        config.registrations_require_3pid = []
+
+        hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        auth_handler = hs.get_auth_handler()
+
+        self.recaptcha_attempts = []
+
+        def _recaptcha(authdict, clientip):
+            self.recaptcha_attempts.append((authdict, clientip))
+            return succeed(True)
+
+        auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+
+    @unittest.INFO
+    def test_fallback_captcha(self):
+
+        request, channel = self.make_request(
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.render(request)
+
+        # Returns a 401 as per the spec
+        self.assertEqual(request.code, 401)
+        # Grab the session
+        session = channel.json_body["session"]
+        # Assert our configured public key is being given
+        self.assertEqual(
+            channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+        )
+
+        request, channel = self.make_request(
+            "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        request, channel = self.make_request(
+            "POST",
+            "auth/m.login.recaptcha/fallback/web?session="
+            + session
+            + "&g-recaptcha-response=a",
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        # The recaptcha handler is called with the response given
+        self.assertEqual(len(self.recaptcha_attempts), 1)
+        self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+
+        # Now we have fufilled the recaptcha fallback step, we can then send a
+        # request to the register API with the session in the authdict.
+        request, channel = self.make_request(
+            "POST", "register", {"auth": {"session": session}}
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a86901c2d8..fd131e3454 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -17,15 +17,20 @@
 import os
 import shutil
 import tempfile
+from binascii import unhexlify
 
 from mock import Mock
+from six.moves.urllib import parse
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
 
+from synapse.config.repository import MediaStorageProviderConfig
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.util.module_loader import load_module
 
 from tests import unittest
 
@@ -83,3 +88,143 @@ class MediaStorageTests(unittest.TestCase):
             body = f.read()
 
         self.assertEqual(test_body, body)
+
+
+class MediaRepoTests(unittest.HomeserverTestCase):
+
+    hijack_auth = True
+    user_id = "@test:user"
+
+    def make_homeserver(self, reactor, clock):
+
+        self.fetches = []
+
+        def get_file(destination, path, output_stream, args=None, max_size=None):
+            """
+            Returns tuple[int,dict,str,int] of file length, response headers,
+            absolute URI, and response code.
+            """
+
+            def write_to(r):
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            d = Deferred()
+            d.addCallback(write_to)
+            self.fetches.append((d, destination, path, args))
+            return d
+
+        client = Mock()
+        client.get_file = get_file
+
+        self.storage_path = self.mktemp()
+        os.mkdir(self.storage_path)
+
+        config = self.default_config()
+        config.media_store_path = self.storage_path
+        config.thumbnail_requirements = {}
+        config.max_image_pixels = 2000000
+
+        provider_config = {
+            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        loaded = list(load_module(provider_config)) + [
+            MediaStorageProviderConfig(False, False, False)
+        ]
+
+        config.media_storage_providers = [loaded]
+
+        hs = self.setup_test_homeserver(config=config, http_client=client)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b'download']
+
+        # smol png
+        self.end_content = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+    def _req(self, content_disposition):
+
+        request, channel = self.make_request(
+            "GET", "example.com/12345", shorthand=False
+        )
+        request.render(self.download_resource)
+        self.pump()
+
+        # We've made one fetch, to example.com, using the media URL, and asking
+        # the other server not to do a remote fetch
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+        )
+        self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.end_content))],
+            b"Content-Type": [b'image/png'],
+        }
+        if content_disposition:
+            headers[b"Content-Disposition"] = [content_disposition]
+
+        self.fetches[0][0].callback(
+            (self.end_content, (len(self.end_content), headers))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        return channel
+
+    def test_disposition_filename_ascii(self):
+        """
+        If the filename is filename=<ascii> then Synapse will decode it as an
+        ASCII string, and use filename= in the response.
+        """
+        channel = self._req(b"inline; filename=out.png")
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+        )
+
+    def test_disposition_filenamestar_utf8escaped(self):
+        """
+        If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+        correctly decode it as the UTF-8 string, and use filename* in the
+        response.
+        """
+        filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+        channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [b"inline; filename*=utf-8''" + filename + b".png"],
+        )
+
+    def test_disposition_none(self):
+        """
+        If there is no filename, one isn't passed on in the Content-Disposition
+        of the request.
+        """
+        channel = self._req(None)
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 29579cf091..86c813200a 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -162,3 +162,80 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
+
+    def test_non_ascii_preview_httpequiv(self):
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # We've made one fetch
+        self.assertEqual(len(self.fetches), 1)
+
+        end_content = (
+            b'<html><head>'
+            b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
+            b'<meta property="og:title" content="\xe4\xea\xe0" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        self.fetches[0][0].callback(
+            (
+                end_content,
+                (
+                    len(end_content),
+                    {
+                        b"Content-Length": [b"%d" % (len(end_content))],
+                        # This charset=utf-8 should be ignored, because the
+                        # document has a meta tag overriding it.
+                        b"Content-Type": [b'text/html; charset="utf8"'],
+                    },
+                    "https://example.com",
+                    200,
+                ),
+            )
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+    def test_non_ascii_preview_content_type(self):
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # We've made one fetch
+        self.assertEqual(len(self.fetches), 1)
+
+        end_content = (
+            b'<html><head>'
+            b'<meta property="og:title" content="\xe4\xea\xe0" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        self.fetches[0][0].callback(
+            (
+                end_content,
+                (
+                    len(end_content),
+                    {
+                        b"Content-Length": [b"%d" % (len(end_content))],
+                        b"Content-Type": [b'text/html; charset="windows-1251"'],
+                    },
+                    "https://example.com",
+                    200,
+                ),
+            )
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")