diff --git a/tests/__init__.py b/tests/__init__.py
index 9d9ca22829..d3181f9403 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +16,9 @@
from twisted.trial import util
-from tests import utils
+import tests.patch_inline_callbacks
+
+# attempt to do the patch before we load any synapse code
+tests.patch_inline_callbacks.do_patch()
util.DEFAULT_TIMEOUT_DURATION = 10
-utils.setupdb()
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 48b2d3d663..2a7044801a 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -60,7 +60,7 @@ class FilteringTestCase(unittest.TestCase):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
- {"event_fields": ["\\foo"]},
+ {"event_fields": [r"\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
@@ -109,6 +109,16 @@ class FilteringTestCase(unittest.TestCase):
"event_format": "client",
"event_fields": ["type", "content", "sender"],
},
+
+ # a single backslash should be permitted (though it is debatable whether
+ # it should be permitted before anything other than `.`, and what that
+ # actually means)
+ #
+ # (note that event_fields is implemented in
+ # synapse.events.utils.serialize_event, and so whether this actually works
+ # is tested elsewhere. We just want to check that it is allowed through the
+ # filter validation)
+ {"event_fields": [r"foo\.bar"]},
]
for filter in valid_filters:
try:
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index f88d28a19d..0c23068bcf 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
- matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
+ matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M)
self.assertEqual(1, len(matches))
self.assertEqual(matches[0], expected)
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
new file mode 100644
index 0000000000..f37a17d618
--- /dev/null
+++ b/tests/config/test_room_directory.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+
+from synapse.config.room_directory import RoomDirectoryConfig
+
+from tests import unittest
+
+
+class RoomDirectoryConfigTestCase(unittest.TestCase):
+ def test_alias_creation_acl(self):
+ config = yaml.load("""
+ alias_creation_rules:
+ - user_id: "*bob*"
+ alias: "*"
+ action: "deny"
+ - user_id: "*"
+ alias: "#unofficial_*"
+ action: "allow"
+ - user_id: "@foo*:example.com"
+ alias: "*"
+ action: "allow"
+ - user_id: "@gah:example.com"
+ alias: "#goo:example.com"
+ action: "allow"
+ """)
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
+ self.assertFalse(rd_config.is_alias_creation_allowed(
+ user_id="@bob:example.com",
+ alias="#test:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ alias="#unofficial_st:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@foobar:example.com",
+ alias="#test:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@gah:example.com",
+ alias="#goo:example.com",
+ ))
+
+ self.assertFalse(rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ alias="#test:example.com",
+ ))
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8299dc72c8..d643bec887 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -63,6 +63,14 @@ class KeyringTestCase(unittest.TestCase):
keys = self.mock_perspective_server.get_verify_keys()
self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ def assert_sentinel_context(self):
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ self.fail(
+ "Expected sentinel context but got %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), expected
@@ -70,8 +78,6 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
- sentinel_context = LoggingContext.current_context()
-
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
@@ -99,8 +105,10 @@ class KeyringTestCase(unittest.TestCase):
["server1"], {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
+
# ... so we should have reset the LoggingContext.
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assert_sentinel_context()
+
wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context)
@@ -198,8 +206,6 @@ class KeyringTestCase(unittest.TestCase):
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
- sentinel_context = LoggingContext.current_context()
-
with LoggingContext("one") as context_one:
context_one.request = "one"
@@ -213,7 +219,7 @@ class KeyringTestCase(unittest.TestCase):
defer = kr.verify_json_for_server("server9", json1)
self.assertFalse(defer.called)
- self.assertIs(LoggingContext.current_context(), sentinel_context)
+ self.assert_sentinel_context()
yield defer
self.assertIs(LoggingContext.current_context(), context_one)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index ff217ca8b9..d0cc492deb 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase):
room_id="!foo:bar",
content={"key.with.dots": {}},
),
- ["content.key\.with\.dots"],
+ [r"content.key\.with\.dots"],
),
{"content": {"key.with.dots": {}}},
)
@@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase):
"nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
},
),
- ["content.nested\.dot\.key.leaf\.key"],
+ [r"content.nested\.dot\.key.leaf\.key"],
),
{"content": {"nested.dot.key": {"leaf.key": 42}}},
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ec7355688b..8ae6556c0a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -18,7 +18,9 @@ from mock import Mock
from twisted.internet import defer
+from synapse.config.room_directory import RoomDirectoryConfig
from synapse.handlers.directory import DirectoryHandler
+from synapse.rest.client.v1 import directory, room
from synapse.types import RoomAlias
from tests import unittest
@@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase):
)
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+
+
+class TestCreateAliasACL(unittest.HomeserverTestCase):
+ user_id = "@test:test"
+
+ servlets = [directory.register_servlets, room.register_servlets]
+
+ def prepare(self, hs, reactor, clock):
+ # We cheekily override the config to add custom alias creation rules
+ config = {}
+ config["alias_creation_rules"] = [
+ {
+ "user_id": "*",
+ "alias": "#unofficial_*",
+ "action": "allow",
+ }
+ ]
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
+ self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed
+
+ return hs
+
+ def test_denied(self):
+ room_id = self.helper.create_room_as(self.user_id)
+
+ request, channel = self.make_request(
+ "PUT",
+ b"directory/room/%23test%3Atest",
+ ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code, channel.result)
+
+ def test_allowed(self):
+ room_id = self.helper.create_room_as(self.user_id)
+
+ request, channel = self.make_request(
+ "PUT",
+ b"directory/room/%23unofficial_test%3Atest",
+ ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
new file mode 100644
index 0000000000..c8994f416e
--- /dev/null
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -0,0 +1,392 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import mock
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import synapse.handlers.e2e_room_keys
+import synapse.storage
+from synapse.api import errors
+
+from tests import unittest, utils
+
+# sample room_key data for use in the tests
+room_keys = {
+ "rooms": {
+ "!abc:matrix.org": {
+ "sessions": {
+ "c0ff33": {
+ "first_message_index": 1,
+ "forwarded_count": 1,
+ "is_verified": False,
+ "session_data": "SSBBTSBBIEZJU0gK"
+ }
+ }
+ }
+ }
+}
+
+
+class E2eRoomKeysHandlerTestCase(unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ self.hs = None # type: synapse.server.HomeServer
+ self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield utils.setup_test_homeserver(
+ self.addCleanup,
+ handlers=None,
+ replication_layer=mock.Mock(),
+ )
+ self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
+ self.local_user = "@boris:" + self.hs.hostname
+
+ @defer.inlineCallbacks
+ def test_get_missing_current_version_info(self):
+ """Check that we get a 404 if we ask for info about the current version
+ if there is no version.
+ """
+ res = None
+ try:
+ yield self.handler.get_version_info(self.local_user)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_get_missing_version_info(self):
+ """Check that we get a 404 if we ask for info about a specific version
+ if it doesn't exist.
+ """
+ res = None
+ try:
+ yield self.handler.get_version_info(self.local_user, "bogus_version")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_create_version(self):
+ """Check that we can create and then retrieve versions.
+ """
+ res = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(res, "1")
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertDictEqual(res, {
+ "version": "1",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+
+ # check we can retrieve it as a specific version
+ res = yield self.handler.get_version_info(self.local_user, "1")
+ self.assertDictEqual(res, {
+ "version": "1",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+
+ # upload a new one...
+ res = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ })
+ self.assertEqual(res, "2")
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertDictEqual(res, {
+ "version": "2",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ })
+
+ @defer.inlineCallbacks
+ def test_delete_missing_version(self):
+ """Check that we get a 404 on deleting nonexistent versions
+ """
+ res = None
+ try:
+ yield self.handler.delete_version(self.local_user, "1")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_delete_missing_current_version(self):
+ """Check that we get a 404 on deleting nonexistent current version
+ """
+ res = None
+ try:
+ yield self.handler.delete_version(self.local_user)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_delete_version(self):
+ """Check that we can create and then delete versions.
+ """
+ res = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(res, "1")
+
+ # check we can delete it
+ yield self.handler.delete_version(self.local_user, "1")
+
+ # check that it's gone
+ res = None
+ try:
+ yield self.handler.get_version_info(self.local_user, "1")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_get_missing_backup(self):
+ """Check that we get a 404 on querying missing backup
+ """
+ res = None
+ try:
+ yield self.handler.get_room_keys(self.local_user, "bogus_version")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_get_missing_room_keys(self):
+ """Check we get an empty response from an empty backup
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
+
+ # TODO: test the locking semantics when uploading room_keys,
+ # although this is probably best done in sytest
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_no_versions(self):
+ """Check that we get a 404 on uploading keys when no versions are defined
+ """
+ res = None
+ try:
+ yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_bogus_version(self):
+ """Check that we get a 404 on uploading keys when an nonexistent version
+ is specified
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ res = None
+ try:
+ yield self.handler.upload_room_keys(
+ self.local_user, "bogus_version", room_keys
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_wrong_version(self):
+ """Check that we get a 403 on uploading keys for an old version
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ })
+ self.assertEqual(version, "2")
+
+ res = None
+ try:
+ yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 403)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_insert(self):
+ """Check that we can insert and retrieve keys for a session
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertDictEqual(res, room_keys)
+
+ # check getting room_keys for a given room
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org"
+ )
+ self.assertDictEqual(res, room_keys)
+
+ # check getting room_keys for a given session_id
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, room_keys)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_merge(self):
+ """Check that we can upload a new room_key for an existing session and
+ have it correctly merged"""
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+
+ new_room_keys = copy.deepcopy(room_keys)
+ new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']
+
+ # test that increasing the message_index doesn't replace the existing session
+ new_room_key['first_message_index'] = 2
+ new_room_key['session_data'] = 'new'
+ yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertEqual(
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ "SSBBTSBBIEZJU0gK"
+ )
+
+ # test that marking the session as verified however /does/ replace it
+ new_room_key['is_verified'] = True
+ yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertEqual(
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ "new"
+ )
+
+ # test that a session with a higher forwarded_count doesn't replace one
+ # with a lower forwarding count
+ new_room_key['forwarded_count'] = 2
+ new_room_key['session_data'] = 'other'
+ yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertEqual(
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ "new"
+ )
+
+ # TODO: check edge cases as well as the common variations here
+
+ @defer.inlineCallbacks
+ def test_delete_room_keys(self):
+ """Check that we can insert and delete keys for a session
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ # check for bulk-delete
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield self.handler.delete_room_keys(self.local_user, version)
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
+
+ # check for bulk-delete per room
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield self.handler.delete_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ )
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
+
+ # check for bulk-delete per session
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield self.handler.delete_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, {
+ "rooms": {}
+ })
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 7b4ade3dfb..90a2a76475 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import ResourceLimitError
from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID, create_requester
+from synapse.types import RoomAlias, UserID, create_requester
from tests.utils import setup_test_homeserver
@@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase):
self.mock_captcha_client = Mock()
self.hs = yield setup_test_homeserver(
self.addCleanup,
- handlers=None,
- http_client=None,
expire_access_token=True,
- profile_handler=Mock(),
)
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
- self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.store = self.hs.get_datastore()
self.hs.config.max_mau_value = 50
self.lots_of_users = 100
self.small_number_of_users = 1
+ self.requester = create_requester("@requester:test")
+
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
- local_part = "someone"
- display_name = "someone"
- user_id = "@someone:test"
- requester = create_requester("@as:test")
+ frank = UserID.from_string("@frank:test")
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, frank.localpart, "Frankie"
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase):
token="jkv;g498752-43gj['eamb!-5",
password_hash=None,
)
- local_part = "frank"
- display_name = "Frank"
- user_id = "@frank:test"
- requester = create_requester("@as:test")
+ local_part = frank.localpart
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, local_part, None
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.handler.get_or_create_user("requester", 'a', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'a', "display_name")
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
@@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- yield self.handler.get_or_create_user("@user:server", 'c', "User")
+ yield self.handler.get_or_create_user(self.requester, 'c', "User")
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
@@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
@defer.inlineCallbacks
def test_register_mau_blocked(self):
@@ -147,3 +143,54 @@ class RegistrationTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part")
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = yield directory_handler.get_association(room_alias)
+
+ self.assertTrue(room_id['room_id'] in rooms)
+ self.assertEqual(len(rooms), 1)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms_with_no_rooms(self):
+ self.hs.config.auto_join_rooms = []
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_room_is_another_domain(self):
+ self.hs.config.auto_join_rooms = ["#room:another"]
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_auto_create_is_false(self):
+ self.hs.config.autocreate_auto_join_rooms = False
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_no_consent(self):
+ self.hs.config.user_consent_at_registration = True
+ self.hs.config.block_events_without_consent_error = "Error"
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ yield self.handler.post_consent_actions(res[0])
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py
new file mode 100644
index 0000000000..61eebb6985
--- /dev/null
+++ b/tests/handlers/test_roomlist.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.handlers.room_list import RoomListNextBatch
+
+import tests.unittest
+import tests.utils
+
+
+class RoomListTestCase(tests.unittest.TestCase):
+ """ Tests RoomList's RoomListNextBatch. """
+
+ def setUp(self):
+ pass
+
+ def test_check_read_batch_tokens(self):
+ batch_token = RoomListNextBatch(
+ stream_ordering="abcdef",
+ public_room_stream_id="123",
+ current_limit=20,
+ direction_is_forward=True,
+ ).to_token()
+ next_batch = RoomListNextBatch.from_token(batch_token)
+ self.assertEquals(next_batch.stream_ordering, "abcdef")
+ self.assertEquals(next_batch.public_room_stream_id, "123")
+ self.assertEquals(next_batch.current_limit, 20)
+ self.assertEquals(next_batch.direction_is_forward, True)
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
new file mode 100644
index 0000000000..0f613945c8
--- /dev/null
+++ b/tests/patch_inline_callbacks.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.util.logcontext import LoggingContext
+
+ orig_inline_callbacks = defer.inlineCallbacks
+
+ def new_inline_callbacks(f):
+
+ orig = orig_inline_callbacks(f)
+
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ err = "%s changed context from %s to %s on exception" % (
+ f, start_context, LoggingContext.current_context()
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ err = "%s changed context from %s to %s" % (
+ f, start_context, LoggingContext.current_context()
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (
+ f, LoggingContext.current_context(), start_context,
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f, start_context, LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
diff --git a/tests/push/__init__.py b/tests/push/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/push/__init__.py
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
new file mode 100644
index 0000000000..50ee6910d1
--- /dev/null
+++ b/tests/push/test_email.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pkg_resources
+
+from twisted.internet.defer import Deferred
+
+from synapse.rest.client.v1 import admin, login, room
+
+from tests.unittest import HomeserverTestCase
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except Exception:
+ load_jinja2_templates = None
+
+
+class EmailPusherTests(HomeserverTestCase):
+
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ # List[Tuple[Deferred, args, kwargs]]
+ self.email_attempts = []
+
+ def sendmail(*args, **kwargs):
+ d = Deferred()
+ self.email_attempts.append((d, args, kwargs))
+ return d
+
+ config = self.default_config()
+ config.email_enable_notifs = True
+ config.start_pushers = True
+
+ config.email_template_dir = os.path.abspath(
+ pkg_resources.resource_filename('synapse', 'res/templates')
+ )
+ config.email_notif_template_html = "notif_mail.html"
+ config.email_notif_template_text = "notif_mail.txt"
+ config.email_smtp_host = "127.0.0.1"
+ config.email_smtp_port = 20
+ config.require_transport_security = False
+ config.email_smtp_user = None
+ config.email_app_name = "Matrix"
+ config.email_notif_from = "test@example.com"
+
+ hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+
+ return hs
+
+ def test_sends_email(self):
+
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="email",
+ app_id="m.email",
+ app_display_name="Email Notifications",
+ device_display_name="a@example.com",
+ pushkey="a@example.com",
+ lang=None,
+ data={},
+ )
+ )
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Get the stream ordering before it gets sent
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ last_stream_ordering = pushers[0]["last_stream_ordering"]
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump(100)
+
+ # It hasn't succeeded yet, so the stream ordering shouldn't have moved
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+
+ # One email was attempted to be sent
+ self.assertEqual(len(self.email_attempts), 1)
+
+ # Make the email succeed
+ self.email_attempts[0][0].callback(True)
+ self.pump()
+
+ # One email was attempted to be sent
+ self.assertEqual(len(self.email_attempts), 1)
+
+ # The stream ordering has increased
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
new file mode 100644
index 0000000000..6dc45e8506
--- /dev/null
+++ b/tests/push/test_http.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet.defer import Deferred
+
+from synapse.rest.client.v1 import admin, login, room
+from synapse.util.logcontext import make_deferred_yieldable
+
+from tests.unittest import HomeserverTestCase
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except Exception:
+ load_jinja2_templates = None
+
+
+class HTTPPusherTests(HomeserverTestCase):
+
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ self.push_attempts = []
+
+ m = Mock()
+
+ def post_json_get_json(url, body):
+ d = Deferred()
+ self.push_attempts.append((d, url, body))
+ return make_deferred_yieldable(d)
+
+ m.post_json_get_json = post_json_get_json
+
+ config = self.default_config()
+ config.start_pushers = True
+
+ hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+
+ return hs
+
+ def test_sends_http(self):
+ """
+ The HTTP pusher will send pushes for each message to a HTTP endpoint
+ when configured to do so.
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Get the stream ordering before it gets sent
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ last_stream_ordering = pushers[0]["last_stream_ordering"]
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # It hasn't succeeded yet, so the stream ordering shouldn't have moved
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+
+ # One push was attempted to be sent -- it'll be the first message
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
+ )
+
+ # Make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # The stream ordering has increased
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+ last_stream_ordering = pushers[0]["last_stream_ordering"]
+
+ # Now it'll try and send the second push message, which will be the second one
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(self.push_attempts[1][1], "example.com")
+ self.assertEqual(
+ self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
+ )
+
+ # Make the second push succeed
+ self.push_attempts[1][0].callback({})
+ self.pump()
+
+ # The stream ordering has increased, again
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
+ )
+ self.assertEqual(len(pushers), 1)
+ self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 41be5d5a1a..1688a741d1 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -28,8 +28,8 @@ ROOM_ID = "!room:blue"
def dict_equals(self, other):
- me = encode_canonical_json(self._event_dict)
- them = encode_canonical_json(other._event_dict)
+ me = encode_canonical_json(self.get_pdu_json())
+ them = encode_canonical_json(other.get_pdu_json())
return me == them
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
new file mode 100644
index 0000000000..4294bbec2a
--- /dev/null
+++ b/tests/rest/client/test_consent.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from synapse.api.urls import ConsentURIBuilder
+from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.consent import consent_resource
+
+from tests import unittest
+from tests.server import render
+
+try:
+ from synapse.push.mailer import load_jinja2_templates
+except Exception:
+ load_jinja2_templates = None
+
+
+class ConsentResourceTestCase(unittest.HomeserverTestCase):
+ skip = "No Jinja installed" if not load_jinja2_templates else None
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+ config.user_consent_version = "1"
+ config.public_baseurl = ""
+ config.form_secret = "123abc"
+
+ # Make some temporary templates...
+ temp_consent_path = self.mktemp()
+ os.mkdir(temp_consent_path)
+ os.mkdir(os.path.join(temp_consent_path, 'en'))
+ config.user_consent_template_dir = os.path.abspath(temp_consent_path)
+
+ with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f:
+ f.write("{{version}},{{has_consented}}")
+
+ with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f:
+ f.write("yay!")
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_render_public_consent(self):
+ """You can observe the terms form without specifying a user"""
+ resource = consent_resource.ConsentResource(self.hs)
+ request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ def test_accept_consent(self):
+ """
+ A user can use the consent form to accept the terms.
+ """
+ uri_builder = ConsentURIBuilder(self.hs.config)
+ resource = consent_resource.ConsentResource(self.hs)
+
+ # Register a user
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Fetch the consent page, to get the consent version
+ consent_uri = (
+ uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ + "&u=user"
+ )
+ request, channel = self.make_request(
+ "GET", consent_uri, access_token=access_token, shorthand=False
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Get the version from the body, and whether we've consented
+ version, consented = channel.result["body"].decode('ascii').split(",")
+ self.assertEqual(consented, "False")
+
+ # POST to the consent page, saying we've agreed
+ request, channel = self.make_request(
+ "POST",
+ consent_uri + "&v=" + version,
+ access_token=access_token,
+ shorthand=False,
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Fetch the consent page, to get the consent version -- it should have
+ # changed
+ request, channel = self.make_request(
+ "GET", consent_uri, access_token=access_token, shorthand=False
+ )
+ render(request, resource, self.reactor)
+ self.assertEqual(channel.code, 200)
+
+ # Get the version from the body, and check that it's the version we
+ # agreed to, and that we've consented to it.
+ version, consented = channel.result["body"].decode('ascii').split(",")
+ self.assertEqual(consented, "True")
+ self.assertEqual(version, "1")
diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 1a553fa3f9..e38eb628a9 100644
--- a/tests/rest/client/v1/test_admin.py
+++ b/tests/rest/client/v1/test_admin.py
@@ -19,24 +19,17 @@ import json
from mock import Mock
-from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
-class UserRegisterTestCase(unittest.TestCase):
- def setUp(self):
+class UserRegisterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [register_servlets]
+
+ def make_homeserver(self, reactor, clock):
- self.clock = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock()
@@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase):
self.secrets = Mock()
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
+ self.hs = self.setup_test_homeserver()
self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()
- self.resource = JsonResource(self.hs)
- register_servlets(self.hs, self.resource)
+ return self.hs
def test_disabled(self):
"""
@@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
self.hs.config.registration_shared_secret = None
- request, channel = make_request("POST", self.url, b'{}')
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, b'{}')
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
@@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase):
self.hs.get_secrets = Mock(return_value=secrets)
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
self.assertEqual(channel.json_body, {"nonce": "abcd"})
@@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
# 59 seconds
- self.clock.advance(59)
+ self.reactor.advance(59)
body = json.dumps({"nonce": nonce})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds
- self.clock.advance(2)
+ self.reactor.advance(2)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
A valid unrecognised nonce.
"""
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac,
}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
@@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase):
"""
def nonce():
- request, channel = make_request("GET", self.url)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("GET", self.url)
+ self.render(request)
return channel.json_body["nonce"]
#
@@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
@@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({"nonce": nonce()})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
@@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
@@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase):
body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
)
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- request, channel = make_request("POST", self.url, body.encode('utf8'))
- render(request, self.resource, self.clock)
+ request, channel = self.make_request("POST", self.url, body.encode('utf8'))
+ self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 6b7ff813d5..f973eff8cf 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase):
)
handlers = Mock(registration_handler=self.registration_handler)
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
+ self.reactor = MemoryReactorClock()
+ self.hs_clock = Clock(self.reactor)
self.hs = self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
+ self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers)
@@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase):
return_value=(user_id, token)
)
- request, channel = make_request(b"POST", url, request_data)
- render(request, res, self.clock)
+ request, channel = make_request(self.reactor, b"POST", url, request_data)
+ render(request, res, self.reactor)
self.assertEquals(channel.result["code"], b"200")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 359f7777ff..a824be9a62 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,7 +23,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.rest.client.v1 import room
+from synapse.rest.client.v1 import admin, login, room
from tests import unittest
@@ -799,3 +799,107 @@ class RoomMessageListTestCase(RoomBase):
self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+
+
+class RoomSearchTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def prepare(self, reactor, clock, hs):
+
+ # Register the user who does the searching
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ # Invite the other person
+ self.helper.invite(
+ room=self.room,
+ src=self.user_id,
+ tok=self.access_token,
+ targ=self.other_user_id,
+ )
+
+ # The other user joins
+ self.helper.join(
+ room=self.room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ def test_finds_message(self):
+ """
+ The search functionality will search for content in messages if asked to
+ do so.
+ """
+ # The other user sends some messages
+ self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
+ self.helper.send(self.room, body="There!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "POST",
+ "/search?access_token=%s" % (self.access_token,),
+ {
+ "search_categories": {
+ "room_events": {"keys": ["content.body"], "search_term": "Hi"}
+ }
+ },
+ )
+ self.render(request)
+
+ # Check we get the results we expect -- one search result, of the sent
+ # messages
+ self.assertEqual(channel.code, 200)
+ results = channel.json_body["search_categories"]["room_events"]
+ self.assertEqual(results["count"], 1)
+ self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
+
+ # No context was requested, so we should get none.
+ self.assertEqual(results["results"][0]["context"], {})
+
+ def test_include_context(self):
+ """
+ When event_context includes include_profile, profile information will be
+ included in the search response.
+ """
+ # The other user sends some messages
+ self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
+ self.helper.send(self.room, body="There!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "POST",
+ "/search?access_token=%s" % (self.access_token,),
+ {
+ "search_categories": {
+ "room_events": {
+ "keys": ["content.body"],
+ "search_term": "Hi",
+ "event_context": {"include_profile": True},
+ }
+ }
+ },
+ )
+ self.render(request)
+
+ # Check we get the results we expect -- one search result, of the sent
+ # messages
+ self.assertEqual(channel.code, 200)
+ results = channel.json_body["search_categories"]["room_events"]
+ self.assertEqual(results["count"], 1)
+ self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
+
+ # We should get context info, like the two users, and the display names.
+ context = results["results"][0]["context"]
+ self.assertEqual(len(context["profile_info"].keys()), 2)
+ self.assertEqual(
+ context["profile_info"][self.other_user_id]["displayname"], "otheruser"
+ )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 530dc8ba6d..9c401bf300 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -169,7 +169,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
request, channel = make_request(
- "POST", path, json.dumps(content).encode('utf8')
+ self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())
@@ -217,7 +217,9 @@ class RestHelper(object):
data = {"membership": membership}
- request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
+ )
render(request, self.resource, self.hs.get_reactor())
@@ -228,18 +230,6 @@ class RestHelper(object):
self.auth_user_id = temp_id
- @defer.inlineCallbacks
- def register(self, user_id):
- (code, response) = yield self.mock_resource.trigger(
- "POST",
- "/_matrix/client/r0/register",
- json.dumps(
- {"user": user_id, "password": "test", "type": "m.login.password"}
- ),
- )
- self.assertEquals(200, code)
- defer.returnValue(response)
-
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -251,7 +241,9 @@ class RestHelper(object):
if tok:
path = path + "?access_token=%s" % tok
- request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
+ )
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
new file mode 100644
index 0000000000..7fa120a10f
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet.defer import succeed
+
+from synapse.api.constants import LoginType
+from synapse.rest.client.v1 import admin
+from synapse.rest.client.v2_alpha import auth, register
+
+from tests import unittest
+
+
+class FallbackAuthTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ auth.register_servlets,
+ admin.register_servlets,
+ register.register_servlets,
+ ]
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+
+ config = self.default_config()
+
+ config.enable_registration_captcha = True
+ config.recaptcha_public_key = "brokencake"
+ config.registrations_require_3pid = []
+
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ auth_handler = hs.get_auth_handler()
+
+ self.recaptcha_attempts = []
+
+ def _recaptcha(authdict, clientip):
+ self.recaptcha_attempts.append((authdict, clientip))
+ return succeed(True)
+
+ auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+
+ @unittest.INFO
+ def test_fallback_captcha(self):
+
+ request, channel = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.render(request)
+
+ # Returns a 401 as per the spec
+ self.assertEqual(request.code, 401)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Assert our configured public key is being given
+ self.assertEqual(
+ channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
+ )
+
+ request, channel = self.make_request(
+ "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ request, channel = self.make_request(
+ "POST",
+ "auth/m.login.recaptcha/fallback/web?session="
+ + session
+ + "&g-recaptcha-response=a",
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
+
+ # The recaptcha handler is called with the response given
+ self.assertEqual(len(self.recaptcha_attempts), 1)
+ self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+
+ # Now we have fufilled the recaptcha fallback step, we can then send a
+ # request to the register API with the session in the authdict.
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session}}
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 6a886ee3b8..f42a8efbf4 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,84 +13,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.types
from synapse.api.errors import Codes
-from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import filter
-from synapse.types import UserID
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock as MemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
PATH_PREFIX = "/_matrix/client/v2_alpha"
-class FilterTestCase(unittest.TestCase):
+class FilterTestCase(unittest.HomeserverTestCase):
- USER_ID = "@apple:test"
+ user_id = "@apple:test"
+ hijack_auth = True
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
- TO_REGISTER = [filter]
+ servlets = [filter.register_servlets]
- def setUp(self):
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
-
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
-
- self.auth = self.hs.get_auth()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.USER_ID),
- "token_id": 1,
- "is_guest": False,
- }
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return synapse.types.create_requester(
- UserID.from_string(self.USER_ID), 1, False, None
- )
-
- self.auth.get_user_by_access_token = get_user_by_access_token
- self.auth.get_user_by_req = get_user_by_req
-
- self.store = self.hs.get_datastore()
- self.filtering = self.hs.get_filtering()
- self.resource = JsonResource(self.hs)
-
- for r in self.TO_REGISTER:
- r.register_servlets(self.hs, self.resource)
+ def prepare(self, reactor, clock, hs):
+ self.filtering = hs.get_filtering()
+ self.store = hs.get_datastore()
def test_add_filter(self):
- request, channel = make_request(
+ request, channel = self.make_request(
"POST",
- "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
- self.clock.advance(0)
+ self.pump()
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self):
- request, channel = make_request(
+ request, channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
- request, channel = make_request(
+ request, channel = self.make_request(
"POST",
- "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
+ "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
@@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
filter_id = self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
- self.clock.advance(1)
+ self.reactor.advance(1)
filter_id = filter_id.result
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self):
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
def test_get_filter_invalid_id(self):
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
- request, channel = make_request(
- "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1c128e81f5..753d5c3e80 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -3,22 +3,19 @@ import json
from mock import Mock
from twisted.python import failure
-from twisted.test.proto_helpers import MemoryReactorClock
from synapse.api.errors import InteractiveAuthIncompleteError
-from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha.register import register_servlets
-from synapse.util import Clock
from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
-class RegisterRestServletTestCase(unittest.TestCase):
- def setUp(self):
+class RegisterRestServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [register_servlets]
+
+ def make_homeserver(self, reactor, clock):
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
self.url = b"/_matrix/client/r0/register"
self.appservice = None
@@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
identity_handler=self.identity_handler,
login_handler=self.login_handler,
)
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
- )
+ self.hs = self.setup_test_homeserver()
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
- self.resource = JsonResource(self.hs)
- register_servlets(self.hs, self.resource)
+ return self.hs
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
@@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
request_data = json.dumps({"username": "kermit"})
- request, channel = make_request(
+ request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
@@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
request_data = json.dumps({"username": "kermit"})
- request, channel = make_request(
+ request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- render(request, self.resource, self.clock)
+ self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.device_handler.check_device_registered = Mock(return_value=device_id)
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
det_data = {
"user_id": user_id,
@@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
- request, channel = make_request(b"POST", self.url, request_data)
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None))
- request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ self.render(request)
det_data = {
"user_id": user_id,
@@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
- request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
- render(request, self.resource, self.clock)
+ request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+ self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 4c30c5f258..99b716f00a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,9 +15,11 @@
from mock import Mock
+from synapse.rest.client.v1 import admin, login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
@@ -65,3 +67,124 @@ class FilterTestCase(unittest.HomeserverTestCase):
["next_batch", "rooms", "account_data", "to_device", "device_lists"]
).issubset(set(channel.json_body.keys()))
)
+
+
+class SyncTypingTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def test_sync_backwards_typing(self):
+ """
+ If the typing serial goes backwards and the typing handler is then reset
+ (such as when the master restarts and sets the typing serial to 0), we
+ do not incorrectly return typing information that had a serial greater
+ than the now-reset serial.
+ """
+ typing_url = "/rooms/%s/typing/%s?access_token=%s"
+ sync_url = "/sync?timeout=3000000&access_token=%s&since=%s"
+
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ request, channel = self.make_request(
+ "GET", "/sync?access_token=%s" % (access_token,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Stop typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": false}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Should return immediately
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Reset typing serial back to 0, as if the master had.
+ typing = self.hs.get_typing_handler()
+ typing._latest_room_serial = 0
+
+ # Since it checks the state token, we need some state to update to
+ # invalidate the stream token.
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # This should time out! But it does not, because our stream token is
+ # ahead, and therefore it's saying the typing (that we've actually
+ # already seen) is new, since it's got a token above our new, now-reset
+ # stream token.
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Clear the typing information, so that it doesn't think everything is
+ # in the future.
+ typing._reset()
+
+ # Now it SHOULD fail as it never completes!
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.assertRaises(TimedOutException, self.render, request)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a86901c2d8..ad5e9a612f 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -17,15 +17,21 @@
import os
import shutil
import tempfile
+from binascii import unhexlify
from mock import Mock
+from six.moves.urllib import parse
from twisted.internet import defer, reactor
+from twisted.internet.defer import Deferred
+from synapse.config.repository import MediaStorageProviderConfig
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.module_loader import load_module
from tests import unittest
@@ -83,3 +89,143 @@ class MediaStorageTests(unittest.TestCase):
body = f.read()
self.assertEqual(test_body, body)
+
+
+class MediaRepoTests(unittest.HomeserverTestCase):
+
+ hijack_auth = True
+ user_id = "@test:user"
+
+ def make_homeserver(self, reactor, clock):
+
+ self.fetches = []
+
+ def get_file(destination, path, output_stream, args=None, max_size=None):
+ """
+ Returns tuple[int,dict,str,int] of file length, response headers,
+ absolute URI, and response code.
+ """
+
+ def write_to(r):
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ d = Deferred()
+ d.addCallback(write_to)
+ self.fetches.append((d, destination, path, args))
+ return make_deferred_yieldable(d)
+
+ client = Mock()
+ client.get_file = get_file
+
+ self.storage_path = self.mktemp()
+ os.mkdir(self.storage_path)
+
+ config = self.default_config()
+ config.media_store_path = self.storage_path
+ config.thumbnail_requirements = {}
+ config.max_image_pixels = 2000000
+
+ provider_config = {
+ "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ loaded = list(load_module(provider_config)) + [
+ MediaStorageProviderConfig(False, False, False)
+ ]
+
+ config.media_storage_providers = [loaded]
+
+ hs = self.setup_test_homeserver(config=config, http_client=client)
+
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b'download']
+
+ # smol png
+ self.end_content = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ def _req(self, content_disposition):
+
+ request, channel = self.make_request(
+ "GET", "example.com/12345", shorthand=False
+ )
+ request.render(self.download_resource)
+ self.pump()
+
+ # We've made one fetch, to example.com, using the media URL, and asking
+ # the other server not to do a remote fetch
+ self.assertEqual(len(self.fetches), 1)
+ self.assertEqual(self.fetches[0][1], "example.com")
+ self.assertEqual(
+ self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+ )
+ self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
+
+ headers = {
+ b"Content-Length": [b"%d" % (len(self.end_content))],
+ b"Content-Type": [b'image/png'],
+ }
+ if content_disposition:
+ headers[b"Content-Disposition"] = [content_disposition]
+
+ self.fetches[0][0].callback(
+ (self.end_content, (len(self.end_content), headers))
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ return channel
+
+ def test_disposition_filename_ascii(self):
+ """
+ If the filename is filename=<ascii> then Synapse will decode it as an
+ ASCII string, and use filename= in the response.
+ """
+ channel = self._req(b"inline; filename=out.png")
+
+ headers = channel.headers
+ self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"]
+ )
+
+ def test_disposition_filenamestar_utf8escaped(self):
+ """
+ If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+ correctly decode it as the UTF-8 string, and use filename* in the
+ response.
+ """
+ filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii')
+ channel = self._req(b"inline; filename*=utf-8''" + filename + b".png")
+
+ headers = channel.headers
+ self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(
+ headers.getRawHeaders(b"Content-Disposition"),
+ [b"inline; filename*=utf-8''" + filename + b".png"],
+ )
+
+ def test_disposition_none(self):
+ """
+ If there is no filename, one isn't passed on in the Content-Disposition
+ of the request.
+ """
+ channel = self._req(None)
+
+ headers = channel.headers
+ self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
+ self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
new file mode 100644
index 0000000000..c62f71b44a
--- /dev/null
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -0,0 +1,242 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from mock import Mock
+
+from twisted.internet.defer import Deferred
+
+from synapse.config.repository import MediaStorageProviderConfig
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.module_loader import load_module
+
+from tests import unittest
+
+
+class URLPreviewTests(unittest.HomeserverTestCase):
+
+ hijack_auth = True
+ user_id = "@test:user"
+
+ def make_homeserver(self, reactor, clock):
+
+ self.storage_path = self.mktemp()
+ os.mkdir(self.storage_path)
+
+ config = self.default_config()
+ config.url_preview_enabled = True
+ config.max_spider_size = 9999999
+ config.url_preview_url_blacklist = []
+ config.media_store_path = self.storage_path
+
+ provider_config = {
+ "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ loaded = list(load_module(provider_config)) + [
+ MediaStorageProviderConfig(False, False, False)
+ ]
+
+ config.media_storage_providers = [loaded]
+
+ hs = self.setup_test_homeserver(config=config)
+
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+
+ self.fetches = []
+
+ def get_file(url, output_stream, max_size):
+ """
+ Returns tuple[int,dict,str,int] of file length, response headers,
+ absolute URI, and response code.
+ """
+
+ def write_to(r):
+ data, response = r
+ output_stream.write(data)
+ return response
+
+ d = Deferred()
+ d.addCallback(write_to)
+ self.fetches.append((d, url))
+ return make_deferred_yieldable(d)
+
+ client = Mock()
+ client.get_file = get_file
+
+ self.media_repo = hs.get_media_repository_resource()
+ preview_url = self.media_repo.children[b'preview_url']
+ preview_url.client = client
+ self.preview_url = preview_url
+
+ def test_cache_returns_correct_type(self):
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # We've made one fetch
+ self.assertEqual(len(self.fetches), 1)
+
+ end_content = (
+ b'<html><head>'
+ b'<meta property="og:title" content="~matrix~" />'
+ b'<meta property="og:description" content="hi" />'
+ b'</head></html>'
+ )
+
+ self.fetches[0][0].callback(
+ (
+ end_content,
+ (
+ len(end_content),
+ {
+ b"Content-Length": [b"%d" % (len(end_content))],
+ b"Content-Type": [b'text/html; charset="utf8"'],
+ },
+ "https://example.com",
+ 200,
+ ),
+ )
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check the cache returns the correct response
+ request, channel = self.make_request(
+ "GET", "url_preview?url=matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # Only one fetch, still, since we'll lean on the cache
+ self.assertEqual(len(self.fetches), 1)
+
+ # Check the cache response has the same content
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Clear the in-memory cache
+ self.assertIn("matrix.org", self.preview_url._cache)
+ self.preview_url._cache.pop("matrix.org")
+ self.assertNotIn("matrix.org", self.preview_url._cache)
+
+ # Check the database cache returns the correct response
+ request, channel = self.make_request(
+ "GET", "url_preview?url=matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # Only one fetch, still, since we'll lean on the cache
+ self.assertEqual(len(self.fetches), 1)
+
+ # Check the cache response has the same content
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ def test_non_ascii_preview_httpequiv(self):
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # We've made one fetch
+ self.assertEqual(len(self.fetches), 1)
+
+ end_content = (
+ b'<html><head>'
+ b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
+ b'<meta property="og:title" content="\xe4\xea\xe0" />'
+ b'<meta property="og:description" content="hi" />'
+ b'</head></html>'
+ )
+
+ self.fetches[0][0].callback(
+ (
+ end_content,
+ (
+ len(end_content),
+ {
+ b"Content-Length": [b"%d" % (len(end_content))],
+ # This charset=utf-8 should be ignored, because the
+ # document has a meta tag overriding it.
+ b"Content-Type": [b'text/html; charset="utf8"'],
+ },
+ "https://example.com",
+ 200,
+ ),
+ )
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
+
+ def test_non_ascii_preview_content_type(self):
+
+ request, channel = self.make_request(
+ "GET", "url_preview?url=matrix.org", shorthand=False
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ # We've made one fetch
+ self.assertEqual(len(self.fetches), 1)
+
+ end_content = (
+ b'<html><head>'
+ b'<meta property="og:title" content="\xe4\xea\xe0" />'
+ b'<meta property="og:description" content="hi" />'
+ b'</head></html>'
+ )
+
+ self.fetches[0][0].callback(
+ (
+ end_content,
+ (
+ len(end_content),
+ {
+ b"Content-Length": [b"%d" % (len(end_content))],
+ b"Content-Type": [b'text/html; charset="windows-1251"'],
+ },
+ "https://example.com",
+ 200,
+ ),
+ )
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/scripts/__init__.py
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
new file mode 100644
index 0000000000..6f56893f5e
--- /dev/null
+++ b/tests/scripts/test_new_matrix_user.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from synapse._scripts.register_new_matrix_user import request_registration
+
+from tests.unittest import TestCase
+
+
+class RegisterTestCase(TestCase):
+ def test_success(self):
+ """
+ The script will fetch a nonce, and then generate a MAC with it, and then
+ post that MAC.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 200
+ r.json = lambda: {"nonce": "a"}
+ return r
+
+ def post(url, json=None, verify=None):
+ # Make sure we are sent the correct info
+ self.assertEqual(json["username"], "user")
+ self.assertEqual(json["password"], "pass")
+ self.assertEqual(json["nonce"], "a")
+ # We want a 40-char hex MAC
+ self.assertEqual(len(json["mac"]), 40)
+
+ r = Mock()
+ r.status_code = 200
+ return r
+
+ requests = Mock()
+ requests.get = get
+ requests.post = post
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # We should get the success message making sure everything is OK.
+ self.assertIn("Success!", out)
+
+ # sys.exit shouldn't have been called.
+ self.assertEqual(err_code, [])
+
+ def test_failure_nonce(self):
+ """
+ If the script fails to fetch a nonce, it throws an error and quits.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 404
+ r.reason = "Not Found"
+ r.json = lambda: {"not": "error"}
+ return r
+
+ requests = Mock()
+ requests.get = get
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # Exit was called
+ self.assertEqual(err_code, [1])
+
+ # We got an error message
+ self.assertIn("ERROR! Received 404 Not Found", out)
+ self.assertNotIn("Success!", out)
+
+ def test_failure_post(self):
+ """
+ The script will fetch a nonce, and then if the final POST fails, will
+ report an error and quit.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 200
+ r.json = lambda: {"nonce": "a"}
+ return r
+
+ def post(url, json=None, verify=None):
+ # Make sure we are sent the correct info
+ self.assertEqual(json["username"], "user")
+ self.assertEqual(json["password"], "pass")
+ self.assertEqual(json["nonce"], "a")
+ # We want a 40-char hex MAC
+ self.assertEqual(len(json["mac"]), 40)
+
+ r = Mock()
+ # Then 500 because we're jerks
+ r.status_code = 500
+ r.reason = "Broken"
+ return r
+
+ requests = Mock()
+ requests.get = get
+ requests.post = post
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # Exit was called
+ self.assertEqual(err_code, [1])
+
+ # We got an error message
+ self.assertIn("ERROR! Received 500 Broken", out)
+ self.assertNotIn("Success!", out)
diff --git a/tests/server.py b/tests/server.py
index 7bee58dff1..ceec2f2d4e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -14,6 +14,8 @@ from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.http import unquote
+from twisted.web.http_headers import Headers
from synapse.http.site import SynapseRequest
from synapse.util import Clock
@@ -21,6 +23,12 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+class TimedOutException(Exception):
+ """
+ A web query timed out.
+ """
+
+
@attr.s
class FakeChannel(object):
"""
@@ -28,6 +36,7 @@ class FakeChannel(object):
wire).
"""
+ _reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict))
_producer = None
@@ -43,6 +52,15 @@ class FakeChannel(object):
raise Exception("No result yet.")
return int(self.result["code"])
+ @property
+ def headers(self):
+ if not self.result:
+ raise Exception("No result yet.")
+ h = Headers()
+ for i in self.result["headers"]:
+ h.addRawHeader(*i)
+ return h
+
def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version
self.result["code"] = code
@@ -50,6 +68,8 @@ class FakeChannel(object):
self.result["headers"] = headers
def write(self, content):
+ assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+
if "body" not in self.result:
self.result["body"] = b""
@@ -57,6 +77,15 @@ class FakeChannel(object):
def registerProducer(self, producer, streaming):
self._producer = producer
+ self.producerStreaming = streaming
+
+ def _produce():
+ if self._producer:
+ self._producer.resumeProducing()
+ self._reactor.callLater(0.1, _produce)
+
+ if not streaming:
+ self._reactor.callLater(0.0, _produce)
def unregisterProducer(self):
if self._producer is None:
@@ -98,10 +127,30 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
+def make_request(
+ reactor,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
+):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
+
+ Args:
+ method (bytes/unicode): The HTTP request method ("verb").
+ path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
+ escaped UTF-8 & spaces and such).
+ content (bytes or dict): The body of the request. JSON-encoded, if
+ a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
+
+ Returns:
+ A synapse.http.site.SynapseRequest.
"""
if not isinstance(method, bytes):
method = method.encode('ascii')
@@ -109,23 +158,29 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
if not isinstance(path, bytes):
path = path.encode('ascii')
- # Decorate it to be the full path
- if not path.startswith(b"/_matrix"):
+ # Decorate it to be the full path, if we're using shorthand
+ if shorthand and not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
+ if not path.startswith(b"/"):
+ path = b"/" + path
+
if isinstance(content, text_type):
content = content.encode('utf8')
site = FakeSite()
- channel = FakeChannel()
+ channel = FakeChannel(reactor)
req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
+ req.postpath = list(map(unquote, path[1:].split(b'/')))
if access_token:
- req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+ req.requestHeaders.addRawHeader(
+ b"Authorization", b"Bearer " + access_token.encode('ascii')
+ )
if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
@@ -151,7 +206,7 @@ def wait_until_result(clock, request, timeout=100):
x += 1
if x > timeout:
- raise Exception("Timed out waiting for request to finish.")
+ raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 4701eedd45..b1551df7ca 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -4,7 +4,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
-from synapse.handlers.auth import AuthHandler
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
@@ -13,17 +12,10 @@ from tests import unittest
from tests.utils import setup_test_homeserver
-class AuthHandlers(object):
- def __init__(self, hs):
- self.auth_handler = AuthHandler(hs)
-
-
class TestResourceLimitsServerNotices(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
- self.hs.handlers = AuthHandlers(self.hs)
- self.auth_handler = self.hs.handlers.auth_handler
+ self.hs = yield setup_test_homeserver(self.addCleanup)
self.server_notices_sender = self.hs.get_server_notices_sender()
# relying on [1] is far from ideal, but the only case where
diff --git a/tests/state/__init__.py b/tests/state/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/state/__init__.py
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
new file mode 100644
index 0000000000..2e073a3afc
--- /dev/null
+++ b/tests/state/test_v2.py
@@ -0,0 +1,759 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+from six.moves import zip
+
+import attr
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.event_auth import auth_types_for_event
+from synapse.events import FrozenEvent
+from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.types import EventID
+
+from tests import unittest
+
+ALICE = "@alice:example.com"
+BOB = "@bob:example.com"
+CHARLIE = "@charlie:example.com"
+EVELYN = "@evelyn:example.com"
+ZARA = "@zara:example.com"
+
+ROOM_ID = "!test:example.com"
+
+MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
+MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
+
+
+ORIGIN_SERVER_TS = 0
+
+
+class FakeEvent(object):
+ """A fake event we use as a convenience.
+
+ NOTE: Again as a convenience we use "node_ids" rather than event_ids to
+ refer to events. The event_id has node_id as localpart and example.com
+ as domain.
+ """
+ def __init__(self, id, sender, type, state_key, content):
+ self.node_id = id
+ self.event_id = EventID(id, "example.com").to_string()
+ self.sender = sender
+ self.type = type
+ self.state_key = state_key
+ self.content = content
+
+ def to_event(self, auth_events, prev_events):
+ """Given the auth_events and prev_events, convert to a Frozen Event
+
+ Args:
+ auth_events (list[str]): list of event_ids
+ prev_events (list[str]): list of event_ids
+
+ Returns:
+ FrozenEvent
+ """
+ global ORIGIN_SERVER_TS
+
+ ts = ORIGIN_SERVER_TS
+ ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1
+
+ event_dict = {
+ "auth_events": [(a, {}) for a in auth_events],
+ "prev_events": [(p, {}) for p in prev_events],
+ "event_id": self.node_id,
+ "sender": self.sender,
+ "type": self.type,
+ "content": self.content,
+ "origin_server_ts": ts,
+ "room_id": ROOM_ID,
+ }
+
+ if self.state_key is not None:
+ event_dict["state_key"] = self.state_key
+
+ return FrozenEvent(event_dict)
+
+
+# All graphs start with this set of events
+INITIAL_EVENTS = [
+ FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ),
+ FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IPOWER",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100}},
+ ),
+ FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ),
+ FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMZ",
+ sender=ZARA,
+ type=EventTypes.Member,
+ state_key=ZARA,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="START",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="END",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+]
+
+INITIAL_EDGES = [
+ "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
+]
+
+
+class StateTestCase(unittest.TestCase):
+ def test_ban_vs_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="MA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content={"membership": Membership.JOIN},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "MA", "PA", "START"],
+ ["END", "PB", "PA"],
+ ]
+
+ expected_state_ids = ["PA", "MA", "MB"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_join_rule_evasion(self):
+ events = [
+ FakeEvent(
+ id="JR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rules": JoinRules.PRIVATE},
+ ),
+ FakeEvent(
+ id="ME",
+ sender=EVELYN,
+ type=EventTypes.Member,
+ state_key=EVELYN,
+ content={"membership": Membership.JOIN},
+ ),
+ ]
+
+ edges = [
+ ["END", "JR", "START"],
+ ["END", "ME", "START"],
+ ]
+
+ expected_state_ids = ["JR"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_offtopic_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PC",
+ sender=CHARLIE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 0,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "PC", "PB", "PA", "START"],
+ ["END", "PA"],
+ ]
+
+ expected_state_ids = ["PC"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_basic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["PA2", "T2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_reset(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "T2", "PA", "T1", "START"],
+ ["END", "T1"],
+ ]
+
+ expected_state_ids = ["T1", "MB", "PA"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MZ1",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="T4",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "MZ1", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["T4", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def do_check(self, events, edges, expected_state_ids):
+ """Take a list of events and edges and calculate the state of the
+ graph at END, and asserts it matches `expected_state_ids`
+
+ Args:
+ events (list[FakeEvent])
+ edges (list[list[str]]): A list of chains of event edges, e.g.
+ `[[A, B, C]]` are edges A->B and B->C.
+ expected_state_ids (list[str]): The expected state at END, (excluding
+ the keys that haven't changed since START).
+ """
+ # We want to sort the events into topological order for processing.
+ graph = {}
+
+ # node_id -> FakeEvent
+ fake_event_map = {}
+
+ for ev in itertools.chain(INITIAL_EVENTS, events):
+ graph[ev.node_id] = set()
+ fake_event_map[ev.node_id] = ev
+
+ for a, b in pairwise(INITIAL_EDGES):
+ graph[a].add(b)
+
+ for edge_list in edges:
+ for a, b in pairwise(edge_list):
+ graph[a].add(b)
+
+ # event_id -> FrozenEvent
+ event_map = {}
+ # node_id -> state
+ state_at_event = {}
+
+ # We copy the map as the sort consumes the graph
+ graph_copy = {k: set(v) for k, v in graph.items()}
+
+ for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
+ fake_event = fake_event_map[node_id]
+ event_id = fake_event.event_id
+
+ prev_events = list(graph[node_id])
+
+ if len(prev_events) == 0:
+ state_before = {}
+ elif len(prev_events) == 1:
+ state_before = dict(state_at_event[prev_events[0]])
+ else:
+ state_d = resolve_events_with_store(
+ [state_at_event[n] for n in prev_events],
+ event_map=event_map,
+ state_res_store=TestStateResolutionStore(event_map),
+ )
+
+ state_before = self.successResultOf(state_d)
+
+ state_after = dict(state_before)
+ if fake_event.state_key is not None:
+ state_after[(fake_event.type, fake_event.state_key)] = event_id
+
+ auth_types = set(auth_types_for_event(fake_event))
+
+ auth_events = []
+ for key in auth_types:
+ if key in state_before:
+ auth_events.append(state_before[key])
+
+ event = fake_event.to_event(auth_events, prev_events)
+
+ state_at_event[node_id] = state_after
+ event_map[event_id] = event
+
+ expected_state = {}
+ for node_id in expected_state_ids:
+ # expected_state_ids are node IDs rather than event IDs,
+ # so we have to convert
+ event_id = EventID(node_id, "example.com").to_string()
+ event = event_map[event_id]
+
+ key = (event.type, event.state_key)
+
+ expected_state[key] = event_id
+
+ start_state = state_at_event["START"]
+ end_state = {
+ key: value
+ for key, value in state_at_event["END"].items()
+ if key in expected_state or start_state.get(key) != value
+ }
+
+ self.assertEqual(expected_state, end_state)
+
+
+class LexicographicalTestCase(unittest.TestCase):
+ def test_simple(self):
+ graph = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
+
+ res = list(lexicographical_topological_sort(graph, key=lambda x: x))
+
+ self.assertEqual(["o", "l", "n", "m", "p"], res)
+
+
+class SimpleParamStateTestCase(unittest.TestCase):
+ def setUp(self):
+ # We build up a simple DAG.
+
+ event_map = {}
+
+ create_event = FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ).to_event([], [])
+ event_map[create_event.event_id] = create_event
+
+ alice_member = FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event([create_event.event_id], [create_event.event_id])
+ event_map[alice_member.event_id] = alice_member
+
+ join_rules = FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ).to_event(
+ auth_events=[create_event.event_id, alice_member.event_id],
+ prev_events=[alice_member.event_id],
+ )
+ event_map[join_rules.event_id] = join_rules
+
+ # Bob and Charlie join at the same time, so there is a fork
+ bob_member = FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event(
+ auth_events=[create_event.event_id, join_rules.event_id],
+ prev_events=[join_rules.event_id],
+ )
+ event_map[bob_member.event_id] = bob_member
+
+ charlie_member = FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event(
+ auth_events=[create_event.event_id, join_rules.event_id],
+ prev_events=[join_rules.event_id],
+ )
+ event_map[charlie_member.event_id] = charlie_member
+
+ self.event_map = event_map
+ self.create_event = create_event
+ self.alice_member = alice_member
+ self.join_rules = join_rules
+ self.bob_member = bob_member
+ self.charlie_member = charlie_member
+
+ self.state_at_bob = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, bob_member]
+ }
+
+ self.state_at_charlie = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, charlie_member]
+ }
+
+ self.expected_combined_state = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, bob_member, charlie_member]
+ }
+
+ def test_event_map_none(self):
+ # Test that we correctly handle passing `None` as the event_map
+
+ state_d = resolve_events_with_store(
+ [self.state_at_bob, self.state_at_charlie],
+ event_map=None,
+ state_res_store=TestStateResolutionStore(self.event_map),
+ )
+
+ state = self.successResultOf(state_d)
+
+ self.assert_dict(self.expected_combined_state, state)
+
+
+def pairwise(iterable):
+ "s -> (s0,s1), (s1,s2), (s2, s3), ..."
+ a, b = itertools.tee(iterable)
+ next(b, None)
+ return zip(a, b)
+
+
+@attr.s
+class TestStateResolutionStore(object):
+ event_map = attr.ib()
+
+ def get_events(self, event_ids, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ """
+
+ return {
+ eid: self.event_map[eid]
+ for eid in event_ids
+ if eid in self.event_map
+ }
+
+ def get_auth_chain(self, event_ids):
+ """Gets the full auth chain for a set of events (including rejected
+ events).
+
+ Includes the given event IDs in the result.
+
+ Note that:
+ 1. All events must be state events.
+ 2. For v1 rooms this may not have the full auth chain in the
+ presence of rejected events
+
+ Args:
+ event_ids (list): The event IDs of the events to fetch the auth
+ chain for. Must be state events.
+
+ Returns:
+ Deferred[list[str]]: List of event IDs of the auth chain.
+ """
+
+ # Simple DFS for auth chain
+ result = set()
+ stack = list(event_ids)
+ while stack:
+ event_id = stack.pop()
+ if event_id in result:
+ continue
+
+ result.add(event_id)
+
+ event = self.event_map[event_id]
+ for aid in event.auth_event_ids():
+ stack.append(aid)
+
+ return list(result)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 8f0aaece40..b83f7336d3 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -45,6 +45,21 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
@defer.inlineCallbacks
+ def test_reupload_key(self):
+ now = 1470174257070
+ json = {"key": "value"}
+
+ yield self.store.store_device("user", "device", None)
+
+ changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ self.assertTrue(changed)
+
+ # If we try to upload the same key then we should be told nothing
+ # changed
+ changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+ self.assertFalse(changed)
+
+ @defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 686f12a0dc..8664bc3d54 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -52,7 +52,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
- self.store.initialise_reserved_users(threepids)
+
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
self.pump()
active_count = self.store.get_monthly_active_count()
@@ -199,7 +202,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
{'medium': 'email', 'address': user2_email},
]
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.initialise_reserved_users(threepids)
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+
self.pump()
count = self.store.get_registered_reserved_users_count()
self.assertEquals(self.get_success(count), 0)
@@ -214,3 +220,28 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
self.store.user_add_threepid(user2, "email", user2_email, now, now)
count = self.store.get_registered_reserved_users_count()
self.assertEquals(self.get_success(count), len(threepids))
+
+ def test_track_monthly_users_without_cap(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = True
+ self.hs.config.max_mau_value = 1 # should not matter
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(0, self.get_success(count))
+
+ self.store.upsert_monthly_active_user("@user1:server")
+ self.store.upsert_monthly_active_user("@user2:server")
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(2, self.get_success(count))
+
+ def test_no_users_when_not_tracking(self):
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = False
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.populate_monthly_active_users("@user:sever")
+ self.pump()
+
+ self.store.upsert_monthly_active_user.assert_not_called()
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b9c5b39d59..086a39d834 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
import tests.unittest
@@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we get the full state as of the final event
state = yield self.store.get_state_for_event(
- e5.event_id, None, filtered_types=None
+ e5.event_id,
)
self.assertIsNotNone(e4)
@@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, '')], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Member, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
self.assertStateMapEqual(
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
)
- # check we can use filtered_types to grab a specific room member
- # without filtering out the other event types
+ # check we can grab a specific room member without filtering out the
+ # other event types
state = yield self.store.get_state_for_event(
e5.event_id,
- [(EventTypes.Member, self.u_alice.to_string())],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ )
)
self.assertStateMapEqual(
@@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
state,
)
- # check that types=[], filtered_types=[EventTypes.Member]
- # doesn't return all members
+ # check that we can grab everything except members
state = yield self.store.get_state_for_event(
- e5.event_id, [], filtered_types=[EventTypes.Member]
+ e5.event_id, state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertStateMapEqual(
@@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
#######################################################
- # _get_some_state_from_cache tests against a full cache
+ # _get_state_for_group_using_cache tests against a full cache
#######################################################
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = list(group_ids.keys())[0]
- # test _get_some_state_from_cache correctly filters out members with types=[]
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({}, state_dict)
- # test _get_some_state_from_cache correctly filters in members with wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=None,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
@@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase):
############################################
# test that things work with a partial cache
- # test _get_some_state_from_cache correctly filters out members with types=[]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({}, state_dict)
- # test _get_some_state_from_cache correctly filters in members wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=None,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=None,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 952a0a7b51..1a5dc32c88 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -112,7 +112,7 @@ class MessageAcceptTests(unittest.TestCase):
"origin_server_ts": 1,
"type": "m.room.message",
"origin": "test.serv",
- "content": "hewwo?",
+ "content": {"body": "hewwo?"},
"auth_events": [],
"prev_events": [("two:test.serv", {}), (most_recent, {})],
}
@@ -123,8 +123,8 @@ class MessageAcceptTests(unittest.TestCase):
"test.serv", lying_event, sent_to_us_directly=True
)
- # Step the reactor, so the database fetches come back
- self.reactor.advance(1)
+ # Step the reactor, so the database fetches come back
+ self.reactor.advance(1)
# on_receive_pdu should throw an error
failure = self.failureResultOf(d)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index bdbacb8448..04f95c942f 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -21,30 +21,20 @@ from mock import Mock, NonCallableMock
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
-from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import register, sync
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
-class TestMauLimit(unittest.TestCase):
- def setUp(self):
- self.reactor = ThreadedMemoryReactorClock()
- self.clock = Clock(self.reactor)
+class TestMauLimit(unittest.HomeserverTestCase):
- self.hs = setup_test_homeserver(
- self.addCleanup,
+ servlets = [register.register_servlets, sync.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.hs = self.setup_test_homeserver(
"red",
http_client=None,
- clock=self.clock,
- reactor=self.reactor,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
@@ -63,10 +53,7 @@ class TestMauLimit(unittest.TestCase):
self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room"
-
- self.resource = JsonResource(self.hs)
- register.register_servlets(self.hs, self.resource)
- sync.register_servlets(self.hs, self.resource)
+ return self.hs
def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated
@@ -184,6 +171,24 @@ class TestMauLimit(unittest.TestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_tracked_but_not_limited(self):
+ self.hs.config.max_mau_value = 1 # should not matter
+ self.hs.config.limit_usage_by_mau = False
+ self.hs.config.mau_stats_only = True
+
+ # Simply being able to create 2 users indicates that the
+ # limit was not reached.
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # We do want to verify that the number of tracked users
+ # matches what we want though
+ count = self.store.get_monthly_active_count()
+ self.reactor.advance(100)
+ self.assertEqual(2, self.successResultOf(count))
+
def create_user(self, localpart):
request_data = json.dumps(
{
@@ -193,8 +198,8 @@ class TestMauLimit(unittest.TestCase):
}
)
- request, channel = make_request("POST", "/register", request_data)
- render(request, self.resource, self.reactor)
+ request, channel = self.make_request("POST", "/register", request_data)
+ self.render(request)
if channel.code != 200:
raise HttpResponseException(
@@ -206,10 +211,10 @@ class TestMauLimit(unittest.TestCase):
return access_token
def do_sync_for_user(self, token):
- request, channel = make_request(
- "GET", "/sync", access_token=token.encode('ascii')
+ request, channel = self.make_request(
+ "GET", "/sync", access_token=token
)
- render(request, self.resource, self.reactor)
+ self.render(request)
if channel.code != 200:
raise HttpResponseException(
diff --git a/tests/test_server.py b/tests/test_server.py
index 4045fdadc3..634a8fbca5 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite, logger
from synapse.util import Clock
+from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
from tests.server import FakeTransport, make_request, render, setup_test_homeserver
@@ -57,7 +58,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
)
- request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
+ request, channel = make_request(
+ self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+ )
render(request, res, self.reactor)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
@@ -75,7 +78,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/_matrix/foo")
+ request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'500')
@@ -93,12 +96,12 @@ class JsonResourceTests(unittest.TestCase):
d = Deferred()
d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True)
- return d
+ return make_deferred_yieldable(d)
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/_matrix/foo")
+ request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'500')
@@ -115,7 +118,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/_matrix/foo")
+ request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'403')
@@ -136,7 +139,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
- request, channel = make_request(b"GET", b"/_matrix/foobar")
+ request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'400')
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
new file mode 100644
index 0000000000..0968e86a7b
--- /dev/null
+++ b/tests/test_terms_auth.py
@@ -0,0 +1,123 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import six
+from mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class TermsTestCase(unittest.HomeserverTestCase):
+ servlets = [register_servlets]
+
+ def prepare(self, reactor, clock, hs):
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+ self.url = "/_matrix/client/r0/register"
+ self.registration_handler = Mock()
+ self.auth_handler = Mock()
+ self.device_handler = Mock()
+ hs.config.enable_registration = True
+ hs.config.registrations_require_3pid = []
+ hs.config.auto_join_rooms = []
+ hs.config.enable_registration_captcha = False
+
+ def test_ui_auth(self):
+ self.hs.config.user_consent_at_registration = True
+ self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
+ self.hs.config.public_baseurl = "https://example.org/"
+ self.hs.config.user_consent_version = "1.0"
+
+ # Do a UI auth request
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ self.assertTrue(channel.json_body is not None)
+ self.assertIsInstance(channel.json_body["session"], six.text_type)
+
+ self.assertIsInstance(channel.json_body["flows"], list)
+ for flow in channel.json_body["flows"]:
+ self.assertIsInstance(flow["stages"], list)
+ self.assertTrue(len(flow["stages"]) > 0)
+ self.assertEquals(flow["stages"][-1], "m.login.terms")
+
+ expected_params = {
+ "m.login.terms": {
+ "policies": {
+ "privacy_policy": {
+ "en": {
+ "name": "My Cool Privacy Policy",
+ "url": "https://example.org/_matrix/consent?v=1.0",
+ },
+ "version": "1.0"
+ },
+ },
+ },
+ }
+ self.assertIsInstance(channel.json_body["params"], dict)
+ self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+
+ # We have to complete the dummy auth stage before completing the terms stage
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.dummy",
+ },
+ }
+ )
+
+ self.registration_handler.check_username = Mock(return_value=True)
+
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
+
+ # We don't bother checking that the response is correct - we'll leave that to
+ # other tests. We just want to make sure we're on the right path.
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ # Finish the UI auth for terms
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.terms",
+ },
+ }
+ )
+ request, channel = self.make_request(b"POST", self.url, request_data)
+ self.render(request)
+
+ # We're interested in getting a response that looks like a successful
+ # registration, not so much that the details are exactly what we want.
+
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.assertTrue(channel.json_body is not None)
+ self.assertIsInstance(channel.json_body["user_id"], six.text_type)
+ self.assertIsInstance(channel.json_body["access_token"], six.text_type)
+ self.assertIsInstance(channel.json_body["device_id"], six.text_type)
diff --git a/tests/unittest.py b/tests/unittest.py
index a59291cc60..092c930396 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import gc
import hashlib
import hmac
import logging
@@ -31,10 +31,12 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.util.logcontext import LoggingContext, LoggingContextFilter
from tests.server import get_clock, make_request, render, setup_test_homeserver
-from tests.utils import default_config
+from tests.utils import default_config, setupdb
+
+setupdb()
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
@@ -102,8 +104,16 @@ class TestCase(unittest.TestCase):
# traceback when a unit test exits leaving things on the reactor.
twisted.internet.base.DelayedCall.debug = True
- old_level = logging.getLogger().level
+ # if we're not starting in the sentinel logcontext, then to be honest
+ # all future bets are off.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ self.fail(
+ "Test starting with non-sentinel logging context %s" % (
+ LoggingContext.current_context(),
+ )
+ )
+ old_level = logging.getLogger().level
if old_level != level:
@around(self)
@@ -115,6 +125,16 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
return orig()
+ @around(self)
+ def tearDown(orig):
+ ret = orig()
+ # force a GC to workaround problems with deferreds leaking logcontexts when
+ # they are GCed (see the logcontext docs)
+ gc.collect()
+ LoggingContext.set_current_context(LoggingContext.sentinel)
+
+ return ret
+
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
@@ -146,6 +166,13 @@ def DEBUG(target):
return target
+def INFO(target):
+ """A decorator to set the .loglevel attribute to logging.INFO.
+ Can apply to either a TestCase or an individual test method."""
+ target.loglevel = logging.INFO
+ return target
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -182,11 +209,11 @@ class HomeserverTestCase(TestCase):
for servlet in self.servlets:
servlet(self.hs, self.resource)
- if hasattr(self, "user_id"):
- from tests.rest.client.v1.utils import RestHelper
+ from tests.rest.client.v1.utils import RestHelper
- self.helper = RestHelper(self.hs, self.resource, self.user_id)
+ self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+ if hasattr(self, "user_id"):
if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False):
@@ -251,7 +278,13 @@ class HomeserverTestCase(TestCase):
"""
def make_request(
- self, method, path, content=b"", access_token=None, request=SynapseRequest
+ self,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
):
"""
Create a SynapseRequest at the path using the method and containing the
@@ -263,6 +296,8 @@ class HomeserverTestCase(TestCase):
escaped UTF-8 & spaces and such).
content (bytes or dict): The body of the request. JSON-encoded, if
a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
Returns:
A synapse.http.site.SynapseRequest.
@@ -270,7 +305,9 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
- return make_request(method, path, content, access_token, request)
+ return make_request(
+ self.reactor, method, path, content, access_token, request, shorthand
+ )
def render(self, request):
"""
@@ -373,5 +410,5 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 200)
- access_token = channel.json_body["access_token"].encode('ascii')
+ access_token = channel.json_body["access_token"]
return access_token
diff --git a/tests/utils.py b/tests/utils.py
index dd347a0c59..52ab762010 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -123,7 +123,10 @@ def default_config(name):
config.user_directory_search_all_users = False
config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None
+ config.user_consent_at_registration = False
+ config.user_consent_policy_name = "Privacy Policy"
config.media_storage_providers = []
+ config.autocreate_auto_join_rooms = True
config.auto_join_rooms = []
config.limit_usage_by_mau = False
config.hs_disabled = False
@@ -131,6 +134,7 @@ def default_config(name):
config.hs_disabled_limit_type = ""
config.max_mau_value = 50
config.mau_trial_days = 0
+ config.mau_stats_only = False
config.mau_limits_reserved_threepids = []
config.admin_contact = None
config.rc_messages_per_second = 10000
|