summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py104
-rw-r--r--tests/rest/media/v1/test_media_storage.py145
-rw-r--r--tests/rest/media/v1/test_url_preview.py77
-rw-r--r--tests/server.py15
-rw-r--r--tests/storage/test_monthly_active_users.py25
-rw-r--r--tests/test_mau.py18
-rw-r--r--tests/test_terms_auth.py2
-rw-r--r--tests/utils.py1
8 files changed, 386 insertions, 1 deletions
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
new file mode 100644
index 0000000000..7fa120a10f
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import admin
+from synapse.rest.client.v2_alpha import auth, register
+
+from tests import unittest
+
+
+class FallbackAuthTests(unittest.HomeserverTestCase):
+
+    servlets = [
+        auth.register_servlets,
+        admin.register_servlets,
+        register.register_servlets,
+    ]
+    hijack_auth = False
+
+    def make_homeserver(self, reactor, clock):
+
+        config = self.default_config()
+
+        config.enable_registration_captcha = True
+        config.recaptcha_public_key = "brokencake"
+        config.registrations_require_3pid = []
+
+        hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        auth_handler = hs.get_auth_handler()
+
+        self.recaptcha_attempts = []
+
+        def _recaptcha(authdict, clientip):
+            self.recaptcha_attempts.append((authdict, clientip))
+            return succeed(True)
+
+        auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+
+    @unittest.INFO
+    def test_fallback_captcha(self):
+
+        request, channel = self.make_request(
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.render(request)
+
+        # Returns a 401 as per the spec
+        self.assertEqual(request.code, 401)
+        # Grab the session
+        session = channel.json_body["session"]
+        # Assert our configured public key is being given
+        self.assertEqual(
+            channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+        )
+
+        request, channel = self.make_request(
+            "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        request, channel = self.make_request(
+            "POST",
+            "auth/m.login.recaptcha/fallback/web?session="
+            + session
+            + "&g-recaptcha-response=a",
+        )
+        self.render(request)
+        self.assertEqual(request.code, 200)
+
+        # The recaptcha handler is called with the response given
+        self.assertEqual(len(self.recaptcha_attempts), 1)
+        self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+
+        # Now we have fufilled the recaptcha fallback step, we can then send a
+        # request to the register API with the session in the authdict.
+        request, channel = self.make_request(
+            "POST", "register", {"auth": {"session": session}}
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a86901c2d8..fd131e3454 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -17,15 +17,20 @@
 import os
 import shutil
 import tempfile
+from binascii import unhexlify
 
 from mock import Mock
+from six.moves.urllib import parse
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
 
+from synapse.config.repository import MediaStorageProviderConfig
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.util.module_loader import load_module
 
 from tests import unittest
 
@@ -83,3 +88,143 @@ class MediaStorageTests(unittest.TestCase):
             body = f.read()
 
         self.assertEqual(test_body, body)
+
+
+class MediaRepoTests(unittest.HomeserverTestCase):
+
+    hijack_auth = True
+    user_id = "@test:user"
+
+    def make_homeserver(self, reactor, clock):
+
+        self.fetches = []
+
+        def get_file(destination, path, output_stream, args=None, max_size=None):
+            """
+            Returns tuple[int,dict,str,int] of file length, response headers,
+            absolute URI, and response code.
+            """
+
+            def write_to(r):
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            d = Deferred()
+            d.addCallback(write_to)
+            self.fetches.append((d, destination, path, args))
+            return d
+
+        client = Mock()
+        client.get_file = get_file
+
+        self.storage_path = self.mktemp()
+        os.mkdir(self.storage_path)
+
+        config = self.default_config()
+        config.media_store_path = self.storage_path
+        config.thumbnail_requirements = {}
+        config.max_image_pixels = 2000000
+
+        provider_config = {
+            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        loaded = list(load_module(provider_config)) + [
+            MediaStorageProviderConfig(False, False, False)
+        ]
+
+        config.media_storage_providers = [loaded]
+
+        hs = self.setup_test_homeserver(config=config, http_client=client)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b'download']
+
+        # smol png
+        self.end_content = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+    def _req(self, content_disposition):
+
+        request, channel = self.make_request(
+            "GET", "example.com/12345", shorthand=False
+        )
+        request.render(self.download_resource)
+        self.pump()
+
+        # We've made one fetch, to example.com, using the media URL, and asking
+        # the other server not to do a remote fetch
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+        )
+        self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.end_content))],
+            b"Content-Type": [b'image/png'],
+        }
+        if content_disposition:
+            headers[b"Content-Disposition"] = [content_disposition]
+
+        self.fetches[0][0].callback(
+            (self.end_content, (len(self.end_content), headers))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        return channel
+
+    def test_disposition_filename_ascii(self):
+        """
+        If the filename is filename=<ascii> then Synapse will decode it as an
+        ASCII string, and use filename= in the response.
+        """
+        channel = self._req(b"inline; filename=out.png")
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+        )
+
+    def test_disposition_filenamestar_utf8escaped(self):
+        """
+        If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+        correctly decode it as the UTF-8 string, and use filename* in the
+        response.
+        """
+        filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+        channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [b"inline; filename*=utf-8''" + filename + b".png"],
+        )
+
+    def test_disposition_none(self):
+        """
+        If there is no filename, one isn't passed on in the Content-Disposition
+        of the request.
+        """
+        channel = self._req(None)
+
+        headers = channel.headers
+        self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+        self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 29579cf091..86c813200a 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -162,3 +162,80 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
+
+    def test_non_ascii_preview_httpequiv(self):
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # We've made one fetch
+        self.assertEqual(len(self.fetches), 1)
+
+        end_content = (
+            b'<html><head>'
+            b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
+            b'<meta property="og:title" content="\xe4\xea\xe0" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        self.fetches[0][0].callback(
+            (
+                end_content,
+                (
+                    len(end_content),
+                    {
+                        b"Content-Length": [b"%d" % (len(end_content))],
+                        # This charset=utf-8 should be ignored, because the
+                        # document has a meta tag overriding it.
+                        b"Content-Type": [b'text/html; charset="utf8"'],
+                    },
+                    "https://example.com",
+                    200,
+                ),
+            )
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+    def test_non_ascii_preview_content_type(self):
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # We've made one fetch
+        self.assertEqual(len(self.fetches), 1)
+
+        end_content = (
+            b'<html><head>'
+            b'<meta property="og:title" content="\xe4\xea\xe0" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        self.fetches[0][0].callback(
+            (
+                end_content,
+                (
+                    len(end_content),
+                    {
+                        b"Content-Length": [b"%d" % (len(end_content))],
+                        b"Content-Type": [b'text/html; charset="windows-1251"'],
+                    },
+                    "https://example.com",
+                    200,
+                ),
+            )
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
diff --git a/tests/server.py b/tests/server.py
index 7919a1f124..ceec2f2d4e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -14,6 +14,8 @@ from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorPluggableNameResolver
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.http import unquote
+from twisted.web.http_headers import Headers
 
 from synapse.http.site import SynapseRequest
 from synapse.util import Clock
@@ -50,6 +52,15 @@ class FakeChannel(object):
             raise Exception("No result yet.")
         return int(self.result["code"])
 
+    @property
+    def headers(self):
+        if not self.result:
+            raise Exception("No result yet.")
+        h = Headers()
+        for i in self.result["headers"]:
+            h.addRawHeader(*i)
+        return h
+
     def writeHeaders(self, version, code, reason, headers):
         self.result["version"] = version
         self.result["code"] = code
@@ -152,6 +163,9 @@ def make_request(
         path = b"/_matrix/client/r0/" + path
         path = path.replace(b"//", b"/")
 
+    if not path.startswith(b"/"):
+        path = b"/" + path
+
     if isinstance(content, text_type):
         content = content.encode('utf8')
 
@@ -161,6 +175,7 @@ def make_request(
     req = request(site, channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
+    req.postpath = list(map(unquote, path[1:].split(b'/')))
 
     if access_token:
         req.requestHeaders.addRawHeader(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 832e379a83..8664bc3d54 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -220,3 +220,28 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
         count = self.store.get_registered_reserved_users_count()
         self.assertEquals(self.get_success(count), len(threepids))
+
+    def test_track_monthly_users_without_cap(self):
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.mau_stats_only = True
+        self.hs.config.max_mau_value = 1  # should not matter
+
+        count = self.store.get_monthly_active_count()
+        self.assertEqual(0, self.get_success(count))
+
+        self.store.upsert_monthly_active_user("@user1:server")
+        self.store.upsert_monthly_active_user("@user2:server")
+        self.pump()
+
+        count = self.store.get_monthly_active_count()
+        self.assertEqual(2, self.get_success(count))
+
+    def test_no_users_when_not_tracking(self):
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.mau_stats_only = False
+        self.store.upsert_monthly_active_user = Mock()
+
+        self.store.populate_monthly_active_users("@user:sever")
+        self.pump()
+
+        self.store.upsert_monthly_active_user.assert_not_called()
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 0afdeb0818..04f95c942f 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -171,6 +171,24 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.code, 403)
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
+    def test_tracked_but_not_limited(self):
+        self.hs.config.max_mau_value = 1  # should not matter
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.mau_stats_only = True
+
+        # Simply being able to create 2 users indicates that the
+        # limit was not reached.
+        token1 = self.create_user("kermit1")
+        self.do_sync_for_user(token1)
+        token2 = self.create_user("kermit2")
+        self.do_sync_for_user(token2)
+
+        # We do want to verify that the number of tracked users
+        # matches what we want though
+        count = self.store.get_monthly_active_count()
+        self.reactor.advance(100)
+        self.assertEqual(2, self.successResultOf(count))
+
     def create_user(self, localpart):
         request_data = json.dumps(
             {
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 9ecc3ef14f..0968e86a7b 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -43,7 +43,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
     def test_ui_auth(self):
         self.hs.config.user_consent_at_registration = True
         self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
-        self.hs.config.public_baseurl = "https://example.org"
+        self.hs.config.public_baseurl = "https://example.org/"
         self.hs.config.user_consent_version = "1.0"
 
         # Do a UI auth request
diff --git a/tests/utils.py b/tests/utils.py
index 67ab916f30..52ab762010 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -134,6 +134,7 @@ def default_config(name):
     config.hs_disabled_limit_type = ""
     config.max_mau_value = 50
     config.mau_trial_days = 0
+    config.mau_stats_only = False
     config.mau_limits_reserved_threepids = []
     config.admin_contact = None
     config.rc_messages_per_second = 10000