diff --git a/tests/__init__.py b/tests/__init__.py
index 9d9ca22829..d3181f9403 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +16,9 @@
from twisted.trial import util
-from tests import utils
+import tests.patch_inline_callbacks
+
+# attempt to do the patch before we load any synapse code
+tests.patch_inline_callbacks.do_patch()
util.DEFAULT_TIMEOUT_DURATION = 10
-utils.setupdb()
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 379e9c4ab1..69dc40428b 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -50,6 +50,8 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
+ self.store.is_support_user = Mock(return_value=defer.succeed(False))
+
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8299dc72c8..d643bec887 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -63,6 +63,14 @@ class KeyringTestCase(unittest.TestCase):
keys = self.mock_perspective_server.get_verify_keys()
self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ def assert_sentinel_context(self):
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ self.fail(
+ "Expected sentinel context but got %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), expected
@@ -70,8 +78,6 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
- sentinel_context = LoggingContext.current_context()
-
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
@@ -99,8 +105,10 @@ class KeyringTestCase(unittest.TestCase):
["server1"], {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
+
# ... so we should have reset the LoggingContext.
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assert_sentinel_context()
+
wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context)
@@ -198,8 +206,6 @@ class KeyringTestCase(unittest.TestCase):
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
- sentinel_context = LoggingContext.current_context()
-
with LoggingContext("one") as context_one:
context_one.request = "one"
@@ -213,7 +219,7 @@ class KeyringTestCase(unittest.TestCase):
defer = kr.verify_json_for_server("server9", json1)
self.assertFalse(defer.called)
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assert_sentinel_context()
yield defer
self.assertIs(LoggingContext.current_context(), context_one)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 3e9a190727..eb70e1daa6 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,7 +17,8 @@ from mock import Mock
from twisted.internet import defer
-from synapse.api.errors import ResourceLimitError
+from synapse.api.constants import UserTypes
+from synapse.api.errors import ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester
@@ -64,6 +65,7 @@ class RegistrationTestCase(unittest.TestCase):
requester, frank.localpart, "Frankie"
)
self.assertEquals(result_user_id, user_id)
+ self.assertTrue(result_token is not None)
self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
@@ -82,7 +84,7 @@ class RegistrationTestCase(unittest.TestCase):
requester, local_part, None
)
self.assertEquals(result_user_id, user_id)
- self.assertEquals(result_token, 'secret')
+ self.assertTrue(result_token is not None)
@defer.inlineCallbacks
def test_mau_limits_when_disabled(self):
@@ -130,27 +132,11 @@ class RegistrationTestCase(unittest.TestCase):
yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks
- def test_register_saml2_mau_blocked(self):
- self.hs.config.limit_usage_by_mau = True
- self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.lots_of_users)
- )
- with self.assertRaises(ResourceLimitError):
- yield self.handler.register_saml2(localpart="local_part")
-
- self.store.get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
- )
- with self.assertRaises(ResourceLimitError):
- yield self.handler.register_saml2(localpart="local_part")
-
- @defer.inlineCallbacks
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
res = yield self.handler.register(localpart='jeff')
rooms = yield self.store.get_rooms_for_user(res[0])
-
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
room_id = yield directory_handler.get_association(room_alias)
@@ -184,3 +170,38 @@ class RegistrationTestCase(unittest.TestCase):
res = yield self.handler.register(localpart='jeff')
rooms = yield self.store.get_rooms_for_user(res[0])
self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms_when_support_user_exists(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+
+ self.store.is_support_user = Mock(return_value=True)
+ res = yield self.handler.register(localpart='support')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ with self.assertRaises(SynapseError):
+ yield directory_handler.get_association(room_alias)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_no_consent(self):
+ self.hs.config.user_consent_at_registration = True
+ self.hs.config.block_events_without_consent_error = "Error"
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ yield self.handler.post_consent_actions(res[0])
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_register_support_user(self):
+ res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT)
+ self.assertTrue(self.store.is_support_user(res[0]))
+
+ @defer.inlineCallbacks
+ def test_register_not_support_user(self):
+ res = yield self.handler.register(localpart='user')
+ self.assertFalse(self.store.is_support_user(res[0]))
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
new file mode 100644
index 0000000000..11f2bae698
--- /dev/null
+++ b/tests/handlers/test_user_directory.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import UserTypes
+from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.storage.roommember import ProfileInfo
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+
+class UserDirectoryHandlers(object):
+ def __init__(self, hs):
+ self.user_directory_handler = UserDirectoryHandler(hs)
+
+
+class UserDirectoryTestCase(unittest.TestCase):
+ """ Tests the UserDirectoryHandler. """
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield setup_test_homeserver(self.addCleanup)
+ self.store = hs.get_datastore()
+ hs.handlers = UserDirectoryHandlers(hs)
+
+ self.handler = hs.get_handlers().user_directory_handler
+
+ @defer.inlineCallbacks
+ def test_handle_local_profile_change_with_support_user(self):
+ support_user_id = "@support:test"
+ yield self.store.register(
+ user_id=support_user_id,
+ token="123",
+ password_hash=None,
+ user_type=UserTypes.SUPPORT
+ )
+
+ yield self.handler.handle_local_profile_change(support_user_id, None)
+ profile = yield self.store.get_user_in_directory(support_user_id)
+ self.assertTrue(profile is None)
+ display_name = 'display_name'
+
+ profile_info = ProfileInfo(
+ avatar_url='avatar_url',
+ display_name=display_name,
+ )
+ regular_user_id = '@regular:test'
+ yield self.handler.handle_local_profile_change(regular_user_id, profile_info)
+ profile = yield self.store.get_user_in_directory(regular_user_id)
+ self.assertTrue(profile['display_name'] == display_name)
+
+ @defer.inlineCallbacks
+ def test_handle_user_deactivated_support_user(self):
+ s_user_id = "@support:test"
+ self.store.register(
+ user_id=s_user_id,
+ token="123",
+ password_hash=None,
+ user_type=UserTypes.SUPPORT
+ )
+
+ self.store.remove_from_user_dir = Mock()
+ self.store.remove_from_user_in_public_room = Mock()
+ yield self.handler.handle_user_deactivated(s_user_id)
+ self.store.remove_from_user_dir.not_called()
+ self.store.remove_from_user_in_public_room.not_called()
+
+ @defer.inlineCallbacks
+ def test_handle_user_deactivated_regular_user(self):
+ r_user_id = "@regular:test"
+ self.store.register(user_id=r_user_id, token="123", password_hash=None)
+ self.store.remove_from_user_dir = Mock()
+ self.store.remove_from_user_in_public_room = Mock()
+ yield self.handler.handle_user_deactivated(r_user_id)
+ self.store.remove_from_user_dir.called_once_with(r_user_id)
+ self.store.remove_from_user_in_public_room.assert_called_once_with(r_user_id)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
new file mode 100644
index 0000000000..0f613945c8
--- /dev/null
+++ b/tests/patch_inline_callbacks.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.util.logcontext import LoggingContext
+
+ orig_inline_callbacks = defer.inlineCallbacks
+
+ def new_inline_callbacks(f):
+
+ orig = orig_inline_callbacks(f)
+
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ err = "%s changed context from %s to %s on exception" % (
+ f, start_context, LoggingContext.current_context()
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ err = "%s changed context from %s to %s" % (
+ f, start_context, LoggingContext.current_context()
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (
+ f, LoggingContext.current_context(), start_context,
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f, start_context, LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index addc01ab7f..6dc45e8506 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from mock import Mock
from twisted.internet.defer import Deferred
from synapse.rest.client.v1 import admin, login, room
+from synapse.util.logcontext import make_deferred_yieldable
from tests.unittest import HomeserverTestCase
@@ -47,7 +48,7 @@ class HTTPPusherTests(HomeserverTestCase):
def post_json_get_json(url, body):
d = Deferred()
self.push_attempts.append((d, url, body))
- return d
+ return make_deferred_yieldable(d)
m.post_json_get_json = post_json_get_json
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index e38eb628a9..407bf0ac4c 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -19,6 +19,7 @@ import json
from mock import Mock
+from synapse.api.constants import UserTypes
from synapse.rest.client.v1.admin import register_servlets
from tests import unittest
@@ -147,7 +148,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+ want_mac.update(
+ nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support"
+ )
want_mac = want_mac.hexdigest()
body = json.dumps(
@@ -156,6 +159,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"username": "bob",
"password": "abc123",
"admin": True,
+ "user_type": UserTypes.SUPPORT,
"mac": want_mac,
}
)
@@ -174,7 +178,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+ want_mac.update(
+ nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin"
+ )
want_mac = want_mac.hexdigest()
body = json.dumps(
@@ -202,8 +208,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
def test_missing_parts(self):
"""
Synapse will complain if you don't give nonce, username, password, and
- mac. Admin is optional. Additional checks are done for length and
- type.
+ mac. Admin and user_types are optional. Additional checks are done for length
+ and type.
"""
def nonce():
@@ -260,7 +266,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual('Invalid username', channel.json_body["error"])
#
- # Username checks
+ # Password checks
#
# Must be present
@@ -296,3 +302,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
+
+ #
+ # user_type check
+ #
+
+ # Invalid user_type
+ body = json.dumps({
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid"}
+ )
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid user type', channel.json_body["error"])
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
new file mode 100644
index 0000000000..7fa120a10f
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import admin
+from synapse.rest.client.v2_alpha import auth, register
+
+from tests import unittest
+
+
+class FallbackAuthTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ auth.register_servlets,
+ admin.register_servlets,
+ register.register_servlets,
+ ]
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+
+ config.enable_registration_captcha = True
+ config.recaptcha_public_key = "brokencake"
+ config.registrations_require_3pid = []
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ auth_handler = hs.get_auth_handler()
+
+ self.recaptcha_attempts = []
+
+ def _recaptcha(authdict, clientip):
+ self.recaptcha_attempts.append((authdict, clientip))
+ return succeed(True)
+
+ auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+
+ @unittest.INFO
+ def test_fallback_captcha(self):
+
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.render(request)
+
+ # Returns a 401 as per the spec
+ self.assertEqual(request.code, 401)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ request, channel = self.make_request(
+ "POST",
+ "auth/m.login.recaptcha/fallback/web?session="
+ + session
+ + "&g-recaptcha-response=a",
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ # The recaptcha handler is called with the response given
+ self.assertEqual(len(self.recaptcha_attempts), 1)
+ self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+
+ # Now we have fufilled the recaptcha fallback step, we can then send a
+ # request to the register API with the session in the authdict.
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session}}
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a86901c2d8..ad5e9a612f 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -17,15 +17,21 @@
import os
import shutil
import tempfile
+from binascii import unhexlify
from mock import Mock
+from six.moves.urllib import parse
from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
+from synapse.config.repository import MediaStorageProviderConfig
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.module_loader import load_module
from tests import unittest
@@ -83,3 +89,143 @@ class MediaStorageTests(unittest.TestCase):
body = f.read()
self.assertEqual(test_body, body)
+
+
+class MediaRepoTests(unittest.HomeserverTestCase):
+
+ hijack_auth = True
+ user_id = "@test:user"
+
+ def make_homeserver(self, reactor, clock):
+
+ self.fetches = []
+
+ def get_file(destination, path, output_stream, args=None, max_size=None):
+ """
+ Returns tuple[int,dict,str,int] of file length, response headers,
+ absolute URI, and response code.
+ """
+
+ def write_to(r):
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ d = Deferred()
+ d.addCallback(write_to)
+ self.fetches.append((d, destination, path, args))
+ return make_deferred_yieldable(d)
+
+ client = Mock()
+ client.get_file = get_file
+
+ self.storage_path = self.mktemp()
+ os.mkdir(self.storage_path)
+
+ config = self.default_config()
+ config.media_store_path = self.storage_path
+ config.thumbnail_requirements = {}
+ config.max_image_pixels = 2000000
+
+ provider_config = {
+ "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ loaded = list(load_module(provider_config)) + [
+ MediaStorageProviderConfig(False, False, False)
+ ]
+
+ config.media_storage_providers = [loaded]
+
+ hs = self.setup_test_homeserver(config=config, http_client=client)
+
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b'download']
+
+ # smol png
+ self.end_content = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ def _req(self, content_disposition):
+
+ request, channel = self.make_request(
+ "GET", "example.com/12345", shorthand=False
+ )
+ request.render(self.download_resource)
+ self.pump()
+
+ # We've made one fetch, to example.com, using the media URL, and asking
+ # the other server not to do a remote fetch
+ self.assertEqual(len(self.fetches), 1)
+ self.assertEqual(self.fetches[0][1], "example.com")
+ self.assertEqual(
+ self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+ )
+ self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.end_content))],
+ b"Content-Type": [b'image/png'],
+ }
+ if content_disposition:
+ headers[b"Content-Disposition"] = [content_disposition]
+
+ self.fetches[0][0].callback(
+ (self.end_content, (len(self.end_content), headers))
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ return channel
+
+ def test_disposition_filename_ascii(self):
+ """
+ If the filename is filename=<ascii> then Synapse will decode it as an
+ ASCII string, and use filename= in the response.
+ """
+ channel = self._req(b"inline; filename=out.png")
+
+ headers = channel.headers
+ self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+ )
+
+ def test_disposition_filenamestar_utf8escaped(self):
+ """
+ If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+ correctly decode it as the UTF-8 string, and use filename* in the
+ response.
+ """
+ filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+ channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+
+ headers = channel.headers
+ self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [b"inline; filename*=utf-8''" + filename + b".png"],
+ )
+
+ def test_disposition_none(self):
+ """
+ If there is no filename, one isn't passed on in the Content-Disposition
+ of the request.
+ """
+ channel = self._req(None)
+
+ headers = channel.headers
+ self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 29579cf091..650ce95a6f 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -15,20 +15,55 @@
import os
-from mock import Mock
+import attr
+from netaddr import IPSet
-from twisted.internet.defer import Deferred
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.error import DNSLookupError
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web._newclient import ResponseDone
from synapse.config.repository import MediaStorageProviderConfig
from synapse.util.module_loader import load_module
from tests import unittest
+from tests.server import FakeTransport
+
+
+@attr.s
+class FakeResponse(object):
+ version = attr.ib()
+ code = attr.ib()
+ phrase = attr.ib()
+ headers = attr.ib()
+ body = attr.ib()
+ absoluteURI = attr.ib()
+
+ @property
+ def request(self):
+ @attr.s
+ class FakeTransport(object):
+ absoluteURI = self.absoluteURI
+
+ return FakeTransport()
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True
user_id = "@test:user"
+ end_content = (
+ b'<html><head>'
+ b'<meta property="og:title" content="~matrix~" />'
+ b'<meta property="og:description" content="hi" />'
+ b'</head></html>'
+ )
def make_homeserver(self, reactor, clock):
@@ -38,6 +73,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
config = self.default_config()
config.url_preview_enabled = True
config.max_spider_size = 9999999
+ config.url_preview_ip_range_blacklist = IPSet(
+ (
+ "192.168.1.1",
+ "1.0.0.0/8",
+ "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
+ "2001:800::/21",
+ )
+ )
+ config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",))
config.url_preview_url_blacklist = []
config.media_store_path = self.storage_path
@@ -61,104 +105,366 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
- self.fetches = []
+ self.media_repo = hs.get_media_repository_resource()
+ self.preview_url = self.media_repo.children[b'preview_url']
- def get_file(url, output_stream, max_size):
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
+ self.lookups = {}
- def write_to(r):
- data, response = r
- output_stream.write(data)
- return response
+ class Resolver(object):
+ def resolveHostName(
+ _self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics='TCP',
+ ):
- d = Deferred()
- d.addCallback(write_to)
- self.fetches.append((d, url))
- return d
+ resolution = HostResolution(hostName)
+ resolutionReceiver.resolutionBegan(resolution)
+ if hostName not in self.lookups:
+ raise DNSLookupError("OH NO")
- client = Mock()
- client.get_file = get_file
+ for i in self.lookups[hostName]:
+ resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber))
+ resolutionReceiver.resolutionComplete()
+ return resolutionReceiver
- self.media_repo = hs.get_media_repository_resource()
- preview_url = self.media_repo.children[b'preview_url']
- preview_url.client = client
- self.preview_url = preview_url
+ self.reactor.nameResolver = Resolver()
def test_cache_returns_correct_type(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check the cache returns the correct response
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # Check the cache response has the same content
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Clear the in-memory cache
+ self.assertIn("http://matrix.org", self.preview_url._cache)
+ self.preview_url._cache.pop("http://matrix.org")
+ self.assertNotIn("http://matrix.org", self.preview_url._cache)
+ # Check the database cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=matrix.org", shorthand=False
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
)
request.render(self.preview_url)
self.pump()
- # We've made one fetch
- self.assertEqual(len(self.fetches), 1)
+ # Check the cache response has the same content
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ def test_non_ascii_preview_httpequiv(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
end_content = (
b'<html><head>'
- b'<meta property="og:title" content="~matrix~" />'
+ b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
+ b'<meta property="og:title" content="\xe4\xea\xe0" />'
b'<meta property="og:description" content="hi" />'
b'</head></html>'
)
- self.fetches[0][0].callback(
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
(
- end_content,
- (
- len(end_content),
- {
- b"Content-Length": [b"%d" % (len(end_content))],
- b"Content-Type": [b'text/html; charset="utf8"'],
- },
- "https://example.com",
- 200,
- ),
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n"
)
+ % (len(end_content),)
+ + end_content
)
self.pump()
self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+ def test_non_ascii_preview_content_type(self):
+ self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+ end_content = (
+ b'<html><head>'
+ b'<meta property="og:title" content="\xe4\xea\xe0" />'
+ b'<meta property="og:description" content="hi" />'
+ b'</head></html>'
)
- # Check the cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=matrix.org", shorthand=False
+ "GET", "url_preview?url=http://matrix.org", shorthand=False
)
request.render(self.preview_url)
self.pump()
- # Only one fetch, still, since we'll lean on the cache
- self.assertEqual(len(self.fetches), 1)
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n"
+ )
+ % (len(end_content),)
+ + end_content
+ )
- # Check the cache response has the same content
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+ def test_ipaddr(self):
+ """
+ IP addresses can be previewed directly.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- # Clear the in-memory cache
- self.assertIn("matrix.org", self.preview_url._cache)
- self.preview_url._cache.pop("matrix.org")
- self.assertNotIn("matrix.org", self.preview_url._cache)
+ def test_blacklisted_ip_specific(self):
+ """
+ Blacklisted IP addresses, found via DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
- # Check the database cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=matrix.org", shorthand=False
+ "GET", "url_preview?url=http://example.com", shorthand=False
)
request.render(self.preview_url)
self.pump()
- # Only one fetch, still, since we'll lean on the cache
- self.assertEqual(len(self.fetches), 1)
+ # No requests made.
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
- # Check the cache response has the same content
+ def test_blacklisted_ip_range(self):
+ """
+ Blacklisted IP ranges, IPs found over DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
+
+ def test_blacklisted_ip_specific_direct(self):
+ """
+ Blacklisted IP addresses, accessed directly, are not spidered.
+ """
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://192.168.1.1", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # No requests made.
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
+
+ def test_blacklisted_ip_range_direct(self):
+ """
+ Blacklisted IP ranges, accessed directly, are not spidered.
+ """
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://1.1.1.2", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
+
+ def test_blacklisted_ip_range_whitelisted_ip(self):
+ """
+ Blacklisted but then subsequently whitelisted IP addresses can be
+ spidered.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
+
+ def test_blacklisted_ip_with_external_ip(self):
+ """
+ If a hostname resolves a blacklisted IP, even if there's a
+ non-blacklisted one, it will be rejected.
+ """
+ # Hardcode the URL resolving to the IP we want.
+ self.lookups[u"example.com"] = [
+ (IPv4Address, "1.1.1.2"),
+ (IPv4Address, "8.8.8.8"),
+ ]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
+
+ def test_blacklisted_ipv6_specific(self):
+ """
+ Blacklisted IP addresses, found via DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [
+ (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+ ]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # No requests made.
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
+
+ def test_blacklisted_ipv6_range(self):
+ """
+ Blacklisted IP ranges, IPs found over DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=http://example.com", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {
+ 'errcode': 'M_UNKNOWN',
+ 'error': 'IP address blocked by IP blacklist entry',
+ },
+ )
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
new file mode 100644
index 0000000000..8d8f03e005
--- /dev/null
+++ b/tests/rest/test_well_known.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.well_known import WellKnownResource
+
+from tests import unittest
+
+
+class WellKnownTests(unittest.HomeserverTestCase):
+ def setUp(self):
+ super(WellKnownTests, self).setUp()
+
+ # replace the JsonResource with a WellKnownResource
+ self.resource = WellKnownResource(self.hs)
+
+ def test_well_known(self):
+ self.hs.config.public_baseurl = "https://tesths"
+ self.hs.config.default_identity_server = "https://testis"
+
+ request, channel = self.make_request(
+ "GET",
+ "/.well-known/matrix/client",
+ shorthand=False,
+ )
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(
+ channel.json_body, {
+ "m.homeserver": {"base_url": "https://tesths"},
+ "m.identity_server": {"base_url": "https://testis"},
+ }
+ )
+
+ def test_well_known_no_public_baseurl(self):
+ self.hs.config.public_baseurl = None
+
+ request, channel = self.make_request(
+ "GET",
+ "/.well-known/matrix/client",
+ shorthand=False,
+ )
+ self.render(request)
+
+ self.assertEqual(request.code, 404)
diff --git a/tests/server.py b/tests/server.py
index 7919a1f124..db43fa0db8 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -14,6 +14,8 @@ from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.http import unquote
+from twisted.web.http_headers import Headers
from synapse.http.site import SynapseRequest
from synapse.util import Clock
@@ -50,6 +52,15 @@ class FakeChannel(object):
raise Exception("No result yet.")
return int(self.result["code"])
+ @property
+ def headers(self):
+ if not self.result:
+ raise Exception("No result yet.")
+ h = Headers()
+ for i in self.result["headers"]:
+ h.addRawHeader(*i)
+ return h
+
def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version
self.result["code"] = code
@@ -152,6 +163,9 @@ def make_request(
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
+ if not path.startswith(b"/"):
+ path = b"/" + path
+
if isinstance(content, text_type):
content = content.encode('utf8')
@@ -161,6 +175,7 @@ def make_request(
req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
+ req.postpath = list(map(unquote, path[1:].split(b'/')))
if access_token:
req.requestHeaders.addRawHeader(
@@ -368,8 +383,16 @@ class FakeTransport(object):
self.disconnecting = True
def pauseProducing(self):
+ if not self.producer:
+ return
+
self.producer.pauseProducing()
+ def resumeProducing(self):
+ if not self.producer:
+ return
+ self.producer.resumeProducing()
+
def unregisterProducer(self):
if not self.producer:
return
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 832e379a83..9605301b59 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -16,6 +16,8 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.constants import UserTypes
+
from tests.unittest import HomeserverTestCase
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -28,6 +30,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
self.store = hs.get_datastore()
hs.config.limit_usage_by_mau = True
hs.config.max_mau_value = 50
+
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
@@ -39,14 +42,23 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
user1_email = "user1@matrix.org"
user2 = "@user2:server"
user2_email = "user2@matrix.org"
+ user3 = "@user3:server"
+ user3_email = "user3@matrix.org"
+
threepids = [
{'medium': 'email', 'address': user1_email},
{'medium': 'email', 'address': user2_email},
+ {'medium': 'email', 'address': user3_email},
]
- user_num = len(threepids)
+ # -1 because user3 is a support user and does not count
+ user_num = len(threepids) - 1
self.store.register(user_id=user1, token="123", password_hash=None)
self.store.register(user_id=user2, token="456", password_hash=None)
+ self.store.register(
+ user_id=user3, token="789",
+ password_hash=None, user_type=UserTypes.SUPPORT
+ )
self.pump()
now = int(self.hs.get_clock().time_msec())
@@ -60,7 +72,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
active_count = self.store.get_monthly_active_count()
- # Test total counts
+ # Test total counts, ensure user3 (support user) is not counted
self.assertEquals(self.get_success(active_count), user_num)
# Test user is marked as active
@@ -149,7 +161,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list
- user_id = "user_id"
+ user_id = "@user_id:host"
self.store.register(
user_id=user_id, token="123", password_hash=None, make_guest=True
)
@@ -220,3 +232,46 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
self.store.user_add_threepid(user2, "email", user2_email, now, now)
count = self.store.get_registered_reserved_users_count()
self.assertEquals(self.get_success(count), len(threepids))
+
+ def test_support_user_not_add_to_mau_limits(self):
+ support_user_id = "@support:test"
+ count = self.store.get_monthly_active_count()
+ self.pump()
+ self.assertEqual(self.get_success(count), 0)
+
+ self.store.register(
+ user_id=support_user_id,
+ token="123",
+ password_hash=None,
+ user_type=UserTypes.SUPPORT
+ )
+
+ self.store.upsert_monthly_active_user(support_user_id)
+ count = self.store.get_monthly_active_count()
+ self.pump()
+ self.assertEqual(self.get_success(count), 0)
+
+ def test_track_monthly_users_without_cap(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = True
+ self.hs.config.max_mau_value = 1 # should not matter
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(0, self.get_success(count))
+
+ self.store.upsert_monthly_active_user("@user1:server")
+ self.store.upsert_monthly_active_user("@user2:server")
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(2, self.get_success(count))
+
+ def test_no_users_when_not_tracking(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = False
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.populate_monthly_active_users("@user:sever")
+ self.pump()
+
+ self.store.upsert_monthly_active_user.assert_not_called()
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 3dfb7b903a..cb3cc4d2e5 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,6 +16,8 @@
from twisted.internet import defer
+from synapse.api.constants import UserTypes
+
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -99,6 +101,26 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user, "access token was not deleted without device_id")
+ @defer.inlineCallbacks
+ def test_is_support_user(self):
+ TEST_USER = "@test:test"
+ SUPPORT_USER = "@support:test"
+
+ res = yield self.store.is_support_user(None)
+ self.assertFalse(res)
+ yield self.store.register(user_id=TEST_USER, token="123", password_hash=None)
+ res = yield self.store.is_support_user(TEST_USER)
+ self.assertFalse(res)
+
+ yield self.store.register(
+ user_id=SUPPORT_USER,
+ token="456",
+ password_hash=None,
+ user_type=UserTypes.SUPPORT
+ )
+ res = yield self.store.is_support_user(SUPPORT_USER)
+ self.assertTrue(res)
+
class TokenGenerator:
def __init__(self):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index e1a34ccffd..1a5dc32c88 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -123,8 +123,8 @@ class MessageAcceptTests(unittest.TestCase):
"test.serv", lying_event, sent_to_us_directly=True
)
- # Step the reactor, so the database fetches come back
- self.reactor.advance(1)
+ # Step the reactor, so the database fetches come back
+ self.reactor.advance(1)
# on_receive_pdu should throw an error
failure = self.failureResultOf(d)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 0afdeb0818..04f95c942f 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -171,6 +171,24 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_tracked_but_not_limited(self):
+ self.hs.config.max_mau_value = 1 # should not matter
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = True
+
+ # Simply being able to create 2 users indicates that the
+ # limit was not reached.
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # We do want to verify that the number of tracked users
+ # matches what we want though
+ count = self.store.get_monthly_active_count()
+ self.reactor.advance(100)
+ self.assertEqual(2, self.successResultOf(count))
+
def create_user(self, localpart):
request_data = json.dumps(
{
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 17897711a1..0ff6d0e283 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -19,6 +19,28 @@ from synapse.metrics import InFlightGauge
from tests import unittest
+def get_sample_labels_value(sample):
+ """ Extract the labels and values of a sample.
+
+ prometheus_client 0.5 changed the sample type to a named tuple with more
+ members than the plain tuple had in 0.4 and earlier. This function can
+ extract the labels and value from the sample for both sample types.
+
+ Args:
+ sample: The sample to get the labels and value from.
+ Returns:
+ A tuple of (labels, value) from the sample.
+ """
+
+ # If the sample has a labels and value attribute, use those.
+ if hasattr(sample, "labels") and hasattr(sample, "value"):
+ return sample.labels, sample.value
+ # Otherwise fall back to treating it as a plain 3 tuple.
+ else:
+ _, labels, value = sample
+ return labels, value
+
+
class TestMauLimit(unittest.TestCase):
def test_basic(self):
gauge = InFlightGauge(
@@ -75,7 +97,7 @@ class TestMauLimit(unittest.TestCase):
for r in gauge.collect():
results[r.name] = {
tuple(labels[x] for x in gauge.labels): value
- for _, labels, value in r.samples
+ for labels, value in map(get_sample_labels_value, r.samples)
}
return results
diff --git a/tests/test_server.py b/tests/test_server.py
index f0e6291b7e..634a8fbca5 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite, logger
from synapse.util import Clock
+from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
from tests.server import FakeTransport, make_request, render, setup_test_homeserver
@@ -95,7 +96,7 @@ class JsonResourceTests(unittest.TestCase):
d = Deferred()
d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True)
- return d
+ return make_deferred_yieldable(d)
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 9ecc3ef14f..0968e86a7b 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -43,7 +43,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
def test_ui_auth(self):
self.hs.config.user_consent_at_registration = True
self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
- self.hs.config.public_baseurl = "https://example.org"
+ self.hs.config.public_baseurl = "https://example.org/"
self.hs.config.user_consent_version = "1.0"
# Do a UI auth request
diff --git a/tests/test_types.py b/tests/test_types.py
index 0f5c8bfaf9..d314a7ff58 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID
+from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
from tests.utils import TestHomeServer
@@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase):
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)
+
+
+class MapUsernameTestCase(unittest.TestCase):
+ def testPassThrough(self):
+ self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
+
+ def testUpperCase(self):
+ self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
+ self.assertEqual(
+ map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
+ "t_e_s_t__1234",
+ )
+
+ def testSymbols(self):
+ self.assertEqual(
+ map_username_to_mxid_localpart("test=$?_1234"),
+ "test=3d=24=3f_1234",
+ )
+
+ def testLeadingUnderscore(self):
+ self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
+
+ def testNonAscii(self):
+ # this should work with either a unicode or a bytes
+ self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
+ self.assertEqual(
+ map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
+ "t=c3=aast",
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index a9ce57da9a..78d2f740f9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import gc
import hashlib
import hmac
import logging
@@ -31,10 +31,12 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext, LoggingContextFilter
from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
+from tests.utils import default_config, setupdb
+
+setupdb()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
@@ -102,8 +104,16 @@ class TestCase(unittest.TestCase):
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
- old_level = logging.getLogger().level
+ # if we're not starting in the sentinel logcontext, then to be honest
+ # all future bets are off.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ self.fail(
+ "Test starting with non-sentinel logging context %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+ old_level = logging.getLogger().level
if old_level != level:
@around(self)
@@ -115,6 +125,16 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
return orig()
+ @around(self)
+ def tearDown(orig):
+ ret = orig()
+ # force a GC to workaround problems with deferreds leaking logcontexts when
+ # they are GCed (see the logcontext docs)
+ gc.collect()
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+
+ return ret
+
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
@@ -353,6 +373,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00admin"
else:
nonce_str += b"\x00notadmin"
+
want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest()
diff --git a/tests/utils.py b/tests/utils.py
index 67ab916f30..08d6faa0a6 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -134,10 +134,14 @@ def default_config(name):
config.hs_disabled_limit_type = ""
config.max_mau_value = 50
config.mau_trial_days = 0
+ config.mau_stats_only = False
config.mau_limits_reserved_threepids = []
config.admin_contact = None
config.rc_messages_per_second = 10000
config.rc_message_burst_count = 10000
+ config.saml2_enabled = False
+ config.public_baseurl = None
+ config.default_identity_server = None
config.use_frozen_dicts = False
|