diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 022d81ce3e..379e9c4ab1 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -457,8 +457,8 @@ class AuthTestCase(unittest.TestCase):
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
- self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
# Ensure does not throw an error
@@ -468,11 +468,37 @@ class AuthTestCase(unittest.TestCase):
yield self.auth.check_auth_blocking()
@defer.inlineCallbacks
+ def test_reserved_threepid(self):
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 1
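+        # Stub the MAU count to be over the configured limit of 1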
+ self.store.get_monthly_active_count = lambda: defer.succeed(2)
+ threepid = {'medium': 'email', 'address': 'reserved@server.com'}
+ unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
+ self.hs.config.mau_limits_reserved_threepids = [threepid]
+
+ yield self.store.register(user_id='user1', token="123", password_hash=None)
+ with self.assertRaises(ResourceLimitError):
+ yield self.auth.check_auth_blocking()
+
+ with self.assertRaises(ResourceLimitError):
+ yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+
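+        # ...but a reserved threepid is allowed through despite the MAU limit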
+ yield self.auth.check_auth_blocking(threepid=threepid)
+
+ @defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
- self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
+
+ @defer.inlineCallbacks
+ def test_server_notices_mxid_special_cased(self):
+ self.hs.config.hs_disabled = True
+ user = "@user:server"
+ self.hs.config.server_notices_mxid = user
+ self.hs.config.hs_disabled_message = "Reason for being disabled"
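+        # The server notices user is special-cased, so this should not raise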
+ yield self.auth.check_auth_blocking(user)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 48b2d3d663..2a7044801a 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -60,7 +60,7 @@ class FilteringTestCase(unittest.TestCase):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
- {"event_fields": ["\\foo"]},
+ {"event_fields": [r"\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
@@ -109,6 +109,16 @@ class FilteringTestCase(unittest.TestCase):
"event_format": "client",
"event_fields": ["type", "content", "sender"],
},
+
+ # a single backslash should be permitted (though it is debatable whether
+ # it should be permitted before anything other than `.`, and what that
+ # actually means)
+ #
+ # (note that event_fields is implemented in
+ # synapse.events.utils.serialize_event, and so whether this actually works
+ # is tested elsewhere. We just want to check that it is allowed through the
+ # filter validation)
+ {"event_fields": [r"foo\.bar"]},
]
for filter in valid_filters:
try:
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 76b5090fff..a83f567ebd 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
- site.resource.children["_matrix"].children["client"].children["r0"]
+ site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = (
- site.resource.children["_matrix"].children["client"].children["r0"]
+ site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
)
request, channel = self.make_request("PUT", "presence/a/status")
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index f88d28a19d..0c23068bcf 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
- matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
+ matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M)
self.assertEqual(1, len(matches))
self.assertEqual(matches[0], expected)
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
new file mode 100644
index 0000000000..f37a17d618
--- /dev/null
+++ b/tests/config/test_room_directory.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+
+from synapse.config.room_directory import RoomDirectoryConfig
+
+from tests import unittest
+
+
+class RoomDirectoryConfigTestCase(unittest.TestCase):
+ def test_alias_creation_acl(self):
+ config = yaml.load("""
+ alias_creation_rules:
+ - user_id: "*bob*"
+ alias: "*"
+ action: "deny"
+ - user_id: "*"
+ alias: "#unofficial_*"
+ action: "allow"
+ - user_id: "@foo*:example.com"
+ alias: "*"
+ action: "allow"
+ - user_id: "@gah:example.com"
+ alias: "#goo:example.com"
+ action: "allow"
+ """)
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
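+        # @bob:example.com is caught by the "*bob*" deny rule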
+ self.assertFalse(rd_config.is_alias_creation_allowed(
+ user_id="@bob:example.com",
+ alias="#test:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ alias="#unofficial_st:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@foobar:example.com",
+ alias="#test:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@gah:example.com",
+ alias="#goo:example.com",
+ ))
+
+ self.assertFalse(rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ alias="#test:example.com",
+ ))
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index ff217ca8b9..d0cc492deb 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase):
room_id="!foo:bar",
content={"key.with.dots": {}},
),
- ["content.key\.with\.dots"],
+ [r"content.key\.with\.dots"],
),
{"content": {"key.with.dots": {}}},
)
@@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase):
"nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
},
),
- ["content.nested\.dot\.key.leaf\.key"],
+ [r"content.nested\.dot\.key.leaf\.key"],
),
{"content": {"nested.dot.key": {"leaf.key": 42}}},
)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 56e7acd37c..a3aa0a1cf2 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
-from tests import unittest, utils
+from tests import unittest
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
-class DeviceTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(DeviceTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
- self.handler = None # type: synapse.handlers.device.DeviceHandler
- self.clock = None # type: utils.MockClock
-
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield utils.setup_test_homeserver(self.addCleanup)
+class DeviceTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
- self.clock = hs.get_clock()
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ # These tests assume that it starts 1000 seconds in.
+ self.reactor.advance(1000)
- @defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
- res = yield self.handler.check_device_registered(
- user_id="@boris:foo",
- device_id="fco",
- initial_device_display_name="display name",
+ res = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="fco",
+ initial_device_display_name="display name",
+ )
)
self.assertEqual(res, "fco")
- dev = yield self.handler.store.get_device("@boris:foo", "fco")
+ dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- @defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
- res1 = yield self.handler.check_device_registered(
- user_id="@boris:foo",
- device_id="fco",
- initial_device_display_name="display name",
+ res1 = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="fco",
+ initial_device_display_name="display name",
+ )
)
self.assertEqual(res1, "fco")
- res2 = yield self.handler.check_device_registered(
- user_id="@boris:foo",
- device_id="fco",
- initial_device_display_name="new display name",
+ res2 = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="fco",
+ initial_device_display_name="new display name",
+ )
)
self.assertEqual(res2, "fco")
- dev = yield self.handler.store.get_device("@boris:foo", "fco")
+ dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- @defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
- device_id = yield self.handler.check_device_registered(
- user_id="@theresa:foo",
- device_id=None,
- initial_device_display_name="display",
+ device_id = self.get_success(
+ self.handler.check_device_registered(
+ user_id="@theresa:foo",
+ device_id=None,
+ initial_device_display_name="display",
+ )
)
- dev = yield self.handler.store.get_device("@theresa:foo", device_id)
+ dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield self._record_users()
+ self._record_users()
+
+ res = self.get_success(self.handler.get_devices_by_user(user1))
- res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res))
device_map = {d["device_id"]: d for d in res}
self.assertDictContainsSubset(
@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
device_map["abc"],
)
- @defer.inlineCallbacks
def test_get_device(self):
- yield self._record_users()
+ self._record_users()
- res = yield self.handler.get_device(user1, "abc")
+ res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertDictContainsSubset(
{
"user_id": user1,
@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_delete_device(self):
- yield self._record_users()
+ self._record_users()
# delete the device
- yield self.handler.delete_device(user1, "abc")
+ self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- with self.assertRaises(synapse.api.errors.NotFoundError):
- yield self.handler.get_device(user1, "abc")
+ res = self.handler.get_device(user1, "abc")
+ self.pump()
+ self.assertIsInstance(
+ self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ )
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
- @defer.inlineCallbacks
def test_update_device(self):
- yield self._record_users()
+ self._record_users()
update = {"display_name": "new display"}
- yield self.handler.update_device(user1, "abc", update)
+ self.get_success(self.handler.update_device(user1, "abc", update))
- res = yield self.handler.get_device(user1, "abc")
+ res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
- @defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
- with self.assertRaises(synapse.api.errors.NotFoundError):
- yield self.handler.update_device("user_id", "unknown_device_id", update)
+ res = self.handler.update_device("user_id", "unknown_device_id", update)
+ self.pump()
+ self.assertIsInstance(
+ self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+ )
- @defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
- yield self._record_user(user1, "xyz", "display 0")
- yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
- yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
- yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
+ self._record_user(user1, "xyz", "display 0")
+ self._record_user(user1, "fco", "display 1", "token1", "ip1")
+ self._record_user(user1, "abc", "display 2", "token2", "ip2")
+ self._record_user(user1, "abc", "display 2", "token3", "ip3")
+
+ self._record_user(user2, "def", "dispkay", "token4", "ip4")
- yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
+ self.reactor.advance(10000)
- @defer.inlineCallbacks
def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None
):
- device_id = yield self.handler.check_device_registered(
- user_id=user_id,
- device_id=device_id,
- initial_device_display_name=display_name,
+ device_id = self.get_success(
+ self.handler.check_device_registered(
+ user_id=user_id,
+ device_id=device_id,
+ initial_device_display_name=display_name,
+ )
)
if ip is not None:
- yield self.store.insert_client_ip(
- user_id, access_token, ip, "user_agent", device_id
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, access_token, ip, "user_agent", device_id
+ )
)
- self.clock.advance_time(1000)
+ self.reactor.advance(1000)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ec7355688b..8ae6556c0a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -18,7 +18,9 @@ from mock import Mock
from twisted.internet import defer
+from synapse.config.room_directory import RoomDirectoryConfig
from synapse.handlers.directory import DirectoryHandler
+from synapse.rest.client.v1 import directory, room
from synapse.types import RoomAlias
from tests import unittest
@@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase):
)
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+
+
+class TestCreateAliasACL(unittest.HomeserverTestCase):
+ user_id = "@test:test"
+
+ servlets = [directory.register_servlets, room.register_servlets]
+
+    def prepare(self, reactor, clock, hs):
+ # We cheekily override the config to add custom alias creation rules
+ config = {}
+ config["alias_creation_rules"] = [
+ {
+ "user_id": "*",
+ "alias": "#unofficial_*",
+ "action": "allow",
+ }
+ ]
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
+ self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed
+
+ return hs
+
+ def test_denied(self):
+ room_id = self.helper.create_room_as(self.user_id)
+
+ request, channel = self.make_request(
+ "PUT",
+ b"directory/room/%23test%3Atest",
+ ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code, channel.result)
+
+ def test_allowed(self):
+ room_id = self.helper.create_room_as(self.user_id)
+
+ request, channel = self.make_request(
+ "PUT",
+ b"directory/room/%23unofficial_test%3Atest",
+ ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
new file mode 100644
index 0000000000..9e08eac0a5
--- /dev/null
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -0,0 +1,397 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import mock
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import synapse.handlers.e2e_room_keys
+import synapse.storage
+from synapse.api import errors
+
+from tests import unittest, utils
+
+# sample room_key data for use in the tests
+room_keys = {
+ "rooms": {
+ "!abc:matrix.org": {
+ "sessions": {
+ "c0ff33": {
+ "first_message_index": 1,
+ "forwarded_count": 1,
+ "is_verified": False,
+ "session_data": "SSBBTSBBIEZJU0gK"
+ }
+ }
+ }
+ }
+}
+
+
+class E2eRoomKeysHandlerTestCase(unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ self.hs = None # type: synapse.server.HomeServer
+ self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield utils.setup_test_homeserver(
+ self.addCleanup,
+ handlers=None,
+ replication_layer=mock.Mock(),
+ )
+ self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
+ self.local_user = "@boris:" + self.hs.hostname
+
+ @defer.inlineCallbacks
+ def test_get_missing_current_version_info(self):
+        """Check that we get a 404 if we ask for info about the current version
+        when no versions exist.
+ """
+ res = None
+ try:
+ yield self.handler.get_version_info(self.local_user)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_get_missing_version_info(self):
+        """Check that we get a 404 if we ask for info about a specific version
+        that doesn't exist.
+ """
+ res = None
+ try:
+ yield self.handler.get_version_info(self.local_user, "bogus_version")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_create_version(self):
+ """Check that we can create and then retrieve versions.
+ """
+ res = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(res, "1")
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertDictEqual(res, {
+ "version": "1",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+
+ # check we can retrieve it as a specific version
+ res = yield self.handler.get_version_info(self.local_user, "1")
+ self.assertDictEqual(res, {
+ "version": "1",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+
+ # upload a new one...
+ res = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ })
+ self.assertEqual(res, "2")
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertDictEqual(res, {
+ "version": "2",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ })
+
+ @defer.inlineCallbacks
+ def test_delete_missing_version(self):
+ """Check that we get a 404 on deleting nonexistent versions
+ """
+ res = None
+ try:
+ yield self.handler.delete_version(self.local_user, "1")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_delete_missing_current_version(self):
+        """Check that we get a 404 on deleting a nonexistent current version
+ """
+ res = None
+ try:
+ yield self.handler.delete_version(self.local_user)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_delete_version(self):
+ """Check that we can create and then delete versions.
+ """
+ res = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(res, "1")
+
+ # check we can delete it
+ yield self.handler.delete_version(self.local_user, "1")
+
+ # check that it's gone
+ res = None
+ try:
+ yield self.handler.get_version_info(self.local_user, "1")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_get_missing_room_keys(self):
+ """Check that we get a 404 on querying missing room_keys
+ """
+ res = None
+ try:
+ yield self.handler.get_room_keys(self.local_user, "bogus_version")
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ # check we also get a 404 even if the version is valid
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ res = None
+ try:
+ yield self.handler.get_room_keys(self.local_user, version)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ # TODO: test the locking semantics when uploading room_keys,
+ # although this is probably best done in sytest
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_no_versions(self):
+ """Check that we get a 404 on uploading keys when no versions are defined
+ """
+ res = None
+ try:
+ yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_bogus_version(self):
+        """Check that we get a 404 on uploading keys when a nonexistent version
+ is specified
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ res = None
+ try:
+ yield self.handler.upload_room_keys(
+ self.local_user, "bogus_version", room_keys
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_wrong_version(self):
+ """Check that we get a 403 on uploading keys for an old version
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "second_version_auth_data",
+ })
+ self.assertEqual(version, "2")
+
+ res = None
+ try:
+ yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 403)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_insert(self):
+ """Check that we can insert and retrieve keys for a session
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertDictEqual(res, room_keys)
+
+ # check getting room_keys for a given room
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org"
+ )
+ self.assertDictEqual(res, room_keys)
+
+ # check getting room_keys for a given session_id
+ res = yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ self.assertDictEqual(res, room_keys)
+
+ @defer.inlineCallbacks
+ def test_upload_room_keys_merge(self):
+ """Check that we can upload a new room_key for an existing session and
+ have it correctly merged"""
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+
+ new_room_keys = copy.deepcopy(room_keys)
+ new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']
+
+ # test that increasing the message_index doesn't replace the existing session
+ new_room_key['first_message_index'] = 2
+ new_room_key['session_data'] = 'new'
+ yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertEqual(
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ "SSBBTSBBIEZJU0gK"
+ )
+
+ # test that marking the session as verified however /does/ replace it
+ new_room_key['is_verified'] = True
+ yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertEqual(
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ "new"
+ )
+
+ # test that a session with a higher forwarded_count doesn't replace one
+        # with a lower forwarded_count
+ new_room_key['forwarded_count'] = 2
+ new_room_key['session_data'] = 'other'
+ yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+
+ res = yield self.handler.get_room_keys(self.local_user, version)
+ self.assertEqual(
+ res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ "new"
+ )
+
+ # TODO: check edge cases as well as the common variations here
+
+ @defer.inlineCallbacks
+ def test_delete_room_keys(self):
+ """Check that we can insert and delete keys for a session
+ """
+ version = yield self.handler.create_version(self.local_user, {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "first_version_auth_data",
+ })
+ self.assertEqual(version, "1")
+
+ # check for bulk-delete
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield self.handler.delete_room_keys(self.local_user, version)
+ res = None
+ try:
+ yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ # check for bulk-delete per room
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield self.handler.delete_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ )
+ res = None
+ try:
+ yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
+
+ # check for bulk-delete per session
+ yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ yield self.handler.delete_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ res = None
+ try:
+ yield self.handler.get_room_keys(
+ self.local_user,
+ version,
+ room_id="!abc:matrix.org",
+ session_id="c0ff33",
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 404)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 40d9aca671..dc140570c6 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError
-from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
@@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
def __init__(self, hs):
- self.profile_handler = ProfileHandler(hs)
+ self.profile_handler = MasterProfileHandler(hs)
class ProfileTestCase(unittest.TestCase):
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 7b4ade3dfb..3e9a190727 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import ResourceLimitError
from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID, create_requester
+from synapse.types import RoomAlias, UserID, create_requester
from tests.utils import setup_test_homeserver
@@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase):
self.mock_captcha_client = Mock()
self.hs = yield setup_test_homeserver(
self.addCleanup,
- handlers=None,
- http_client=None,
expire_access_token=True,
- profile_handler=Mock(),
)
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
- self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.store = self.hs.get_datastore()
self.hs.config.max_mau_value = 50
self.lots_of_users = 100
self.small_number_of_users = 1
+ self.requester = create_requester("@requester:test")
+
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
- local_part = "someone"
- display_name = "someone"
- user_id = "@someone:test"
- requester = create_requester("@as:test")
+ frank = UserID.from_string("@frank:test")
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, frank.localpart, "Frankie"
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase):
token="jkv;g498752-43gj['eamb!-5",
password_hash=None,
)
- local_part = "frank"
- display_name = "Frank"
- user_id = "@frank:test"
- requester = create_requester("@as:test")
+ local_part = frank.localpart
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, local_part, None
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.handler.get_or_create_user("requester", 'a', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'a', "display_name")
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
@@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- yield self.handler.get_or_create_user("@user:server", 'c', "User")
+ yield self.handler.get_or_create_user(self.requester, 'c', "User")
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
@@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
@defer.inlineCallbacks
def test_register_mau_blocked(self):
@@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part")
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = yield directory_handler.get_association(room_alias)
+
+ self.assertTrue(room_id['room_id'] in rooms)
+ self.assertEqual(len(rooms), 1)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms_with_no_rooms(self):
+ self.hs.config.auto_join_rooms = []
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_room_is_another_domain(self):
+ self.hs.config.auto_join_rooms = ["#room:another"]
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_auto_create_is_false(self):
+ self.hs.config.autocreate_auto_join_rooms = False
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py
new file mode 100644
index 0000000000..61eebb6985
--- /dev/null
+++ b/tests/handlers/test_roomlist.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.handlers.room_list import RoomListNextBatch
+
+import tests.unittest
+import tests.utils
+
+
+class RoomListTestCase(tests.unittest.TestCase):
+ """ Tests RoomList's RoomListNextBatch. """
+
+ def setUp(self):
+ pass
+
+ def test_check_read_batch_tokens(self):
+ batch_token = RoomListNextBatch(
+ stream_ordering="abcdef",
+ public_room_stream_id="123",
+ current_limit=20,
+ direction_is_forward=True,
+ ).to_token()
+ next_batch = RoomListNextBatch.from_token(batch_token)
+ self.assertEquals(next_batch.stream_ordering, "abcdef")
+ self.assertEquals(next_batch.public_room_stream_id, "123")
+ self.assertEquals(next_batch.current_limit, 20)
+ self.assertEquals(next_batch.direction_is_forward, True)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index a01ab471f5..31f54bbd7d 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase):
self.hs.config.hs_disabled = True
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
@@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase):
with self.assertRaises(ResourceLimitError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
+ self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ad58073a14..36e136cded 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -33,7 +33,7 @@ from ..utils import (
)
-def _expect_edu(destination, edu_type, content, origin="test"):
+def _expect_edu_transaction(edu_type, content, origin="test"):
return {
"origin": origin,
"origin_server_ts": 1000000,
@@ -42,10 +42,8 @@ def _expect_edu(destination, edu_type, content, origin="test"):
}
-def _make_edu_json(origin, edu_type, content):
- return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode(
- 'utf8'
- )
+def _make_edu_transaction_json(edu_type, content):
+ return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
class TypingNotificationsTestCase(unittest.TestCase):
@@ -190,8 +188,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
- data=_expect_edu(
- "farm",
+ data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,
@@ -221,11 +218,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- yield self.mock_federation_resource.trigger(
+ (code, response) = yield self.mock_federation_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1000000/",
- _make_edu_json(
- "farm",
+ _make_edu_transaction_json(
"m.typing",
content={
"room_id": self.room_id,
@@ -233,7 +229,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"typing": True,
},
),
- federation_auth=True,
+ federation_auth_origin=b'farm',
)
self.on_new_event.assert_has_calls(
@@ -264,8 +260,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
- data=_expect_edu(
- "farm",
+ data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
new file mode 100644
index 0000000000..f3cb1423f0
--- /dev/null
+++ b/tests/http/test_fedclient.py
@@ -0,0 +1,190 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet.defer import TimeoutError
+from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.web.client import ResponseNeverReceived
+from twisted.web.http import HTTPChannel
+
+from synapse.http.matrixfederationclient import (
+ MatrixFederationHttpClient,
+ MatrixFederationRequest,
+)
+
+from tests.server import FakeTransport
+from tests.unittest import HomeserverTestCase
+
+
+class FederationClientTests(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
+ hs.tls_client_options_factory = None
+ return hs
+
+ def prepare(self, reactor, clock, homeserver):
+
+ self.cl = MatrixFederationHttpClient(self.hs)
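+        # Give the fake reactor a DNS entry for the test server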
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ def test_dns_error(self):
+ """
+        If the DNS lookup returns an error, it will bubble up.
+ """
+ d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ self.pump()
+
+ f = self.failureResultOf(d)
+ self.assertIsInstance(f.value, DNSLookupError)
+
+ def test_client_never_connect(self):
+ """
+ If the HTTP request is not connected and is timed out, it'll give a
+ ConnectingCancelledError or TimeoutError.
+ """
+ d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertFalse(d.called)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], '1.2.3.4')
+ self.assertEqual(clients[0][1], 8008)
+
+ # Deferred is still without a result
+ self.assertFalse(d.called)
+
+ # Push by enough to time it out
+ self.reactor.advance(10.5)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, (ConnectingCancelledError, TimeoutError))
+
+ def test_client_connect_no_response(self):
+ """
+ If the HTTP request is connected, but gets no response before being
+ timed out, it'll give a ResponseNeverReceived.
+ """
+ d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertFalse(d.called)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], '1.2.3.4')
+ self.assertEqual(clients[0][1], 8008)
+
+ conn = Mock()
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred is still without a result
+ self.assertFalse(d.called)
+
+ # Push by enough to time it out
+ self.reactor.advance(10.5)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, ResponseNeverReceived)
+
+ def test_client_gets_headers(self):
+ """
+ Once the client gets the headers, _request returns successfully.
+ """
+ request = MatrixFederationRequest(
+ method="GET",
+ destination="testserv:8008",
+ path="foo/bar",
+ )
+ d = self.cl._send_request(request, timeout=10000)
+
+ self.pump()
+
+ conn = Mock()
+ clients = self.reactor.tcpClients
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred does not have a result
+ self.assertFalse(d.called)
+
+ # Send it the HTTP response
+ client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
+
+ # We should get a successful response
+ r = self.successResultOf(d)
+ self.assertEqual(r.code, 200)
+
+ def test_client_headers_no_body(self):
+ """
+        If the HTTP request is connected and gets headers but no response body
+        before being timed out, it'll give a TimeoutError.
+ """
+ d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+
+ self.pump()
+
+ conn = Mock()
+ clients = self.reactor.tcpClients
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred does not have a result
+ self.assertFalse(d.called)
+
+ # Send it the HTTP response
+ client.dataReceived(
+ (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
+ b"Server: Fake\r\n\r\n")
+ )
+
+ # Push by enough to time it out
+ self.reactor.advance(10.5)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, TimeoutError)
+
+ def test_client_sends_body(self):
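+        """
+        A POST request sends its JSON body over the connection.
+        """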
+ self.cl.post_json(
+ "testserv:8008", "foo/bar", timeout=10000,
+ data={"a": "b"}
+ )
+
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ client = clients[0][2].buildProtocol(None)
+ server = HTTPChannel()
+
+ client.makeConnection(FakeTransport(server, self.reactor))
+ server.makeConnection(FakeTransport(client, self.reactor))
+
+ self.pump(0.1)
+
+ self.assertEqual(len(server.requests), 1)
+ request = server.requests[0]
+ content = request.content.read()
+ self.assertEqual(content, b'{"a":"b"}')
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 65df116efc..9e9fbbfe93 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,89 +12,62 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import tempfile
from mock import Mock, NonCallableMock
-from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
-
from synapse.replication.tcp.client import (
ReplicationClientFactory,
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest
-from tests.utils import setup_test_homeserver
-
-
-class TestReplicationClientHandler(ReplicationClientHandler):
- """Overrides on_rdata so that we can wait for it to happen"""
-
- def __init__(self, store):
- super(TestReplicationClientHandler, self).__init__(store)
- self._rdata_awaiters = []
-
- def await_replication(self):
- d = Deferred()
- self._rdata_awaiters.append(d)
- return make_deferred_yieldable(d)
+from tests.server import FakeTransport
- def on_rdata(self, stream_name, token, rows):
- awaiters = self._rdata_awaiters
- self._rdata_awaiters = []
- super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
- with PreserveLoggingContext():
- for a in awaiters:
- a.callback(None)
+class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
-class BaseSlavedStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(
- self.addCleanup,
+ hs = self.setup_test_homeserver(
"blue",
- http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
- self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+ hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+ return hs
+
+ def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- # XXX: mktemp is unsafe and should never be used. but we're just a test.
- path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
- listener = reactor.listenUNIX(path, server_factory)
- self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
- self.replication_handler = TestReplicationClientHandler(self.slaved_store)
+ self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
- client_connector = reactor.connectUNIX(path, client_factory)
- self.addCleanup(client_factory.stopTrying)
- self.addCleanup(client_connector.disconnect)
+
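+        # Wire the replication client and server together in memory with
+        # FakeTransport, rather than over a real UNIX socket.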
+ server = server_factory.buildProtocol(None)
+ client = client_factory.buildProtocol(None)
+
+ client.makeConnection(FakeTransport(server, reactor))
+ server.makeConnection(FakeTransport(client, reactor))
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
- # xxx: should we be more specific in what we wait for?
- d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke()
- return d
+ self.pump(0.1)
- @defer.inlineCallbacks
def check(self, method, args, expected_result=None):
- master_result = yield getattr(self.master_store, method)(*args)
- slaved_result = yield getattr(self.slaved_store, method)(*args)
+ master_result = self.get_success(getattr(self.master_store, method)(*args))
+ slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index 87cc2b2fba..43e3248703 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from ._base import BaseSlavedStoreTestCase
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedAccountDataStore
- @defer.inlineCallbacks
def test_user_account_data(self):
- yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
- yield self.replicate()
- yield self.check(
+ self.get_success(
+ self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
+ )
+ self.replicate()
+ self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
)
- yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
- yield self.replicate()
- yield self.check(
+ self.get_success(
+ self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
+ )
+ self.replicate()
+ self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 622be2eef8..41be5d5a1a 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from canonicaljson import encode_canonical_json
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
@@ -28,7 +28,9 @@ ROOM_ID = "!room:blue"
def dict_equals(self, other):
- return self.__dict__ == other.__dict__
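+    # Compare events by their canonical-JSON-encoded contents rather than by
+    # comparing raw instance __dict__s.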
+ me = encode_canonical_json(self._event_dict)
+ them = encode_canonical_json(other._event_dict)
+ return me == them
def patch__eq__(cls):
@@ -55,69 +57,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
- @defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
- create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.replicate()
- yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+ create = self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.replicate()
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
- join = yield self.persist(
+ join = self.persist(
type="m.room.member",
key=USER_ID,
membership="join",
prev_events=[(create.event_id, {})],
)
- yield self.replicate()
- yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+ self.replicate()
+ self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
- @defer.inlineCallbacks
def test_redactions(self):
- yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
- msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
- yield self.replicate()
- yield self.check("get_event", [msg.event_id], msg)
+ msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+ self.replicate()
+ self.check("get_event", [msg.event_id], msg)
- redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
- yield self.replicate()
+ redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
+ self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
- yield self.check("get_event", [msg.event_id], redacted)
+ self.check("get_event", [msg.event_id], redacted)
- @defer.inlineCallbacks
def test_backfilled_redactions(self):
- yield self.persist(type="m.room.create", key="", creator=USER_ID)
- yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
- msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
- yield self.replicate()
- yield self.check("get_event", [msg.event_id], msg)
+ msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+ self.replicate()
+ self.check("get_event", [msg.event_id], msg)
- redaction = yield self.persist(
+ redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
- yield self.replicate()
+ self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
- yield self.check("get_event", [msg.event_id], redacted)
+ self.check("get_event", [msg.event_id], redacted)
- @defer.inlineCallbacks
def test_invites(self):
- yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
- event = yield self.persist(
- type="m.room.member", key=USER_ID_2, membership="invite"
- )
- yield self.replicate()
- yield self.check(
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
+
+ self.replicate()
+
+ self.check(
"get_invited_rooms_for_user",
[USER_ID_2],
[
@@ -131,37 +130,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
],
)
- @defer.inlineCallbacks
def test_push_actions_for_user(self):
- yield self.persist(type="m.room.create", creator=USER_ID)
- yield self.persist(type="m.room.join", key=USER_ID, membership="join")
- yield self.persist(
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.join", key=USER_ID, membership="join")
+ self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
)
- event1 = yield self.persist(
- type="m.room.message", msgtype="m.text", body="hello"
- )
- yield self.replicate()
- yield self.check(
+ event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
+ self.replicate()
+ self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0},
)
- yield self.persist(
+ self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
push_actions=[(USER_ID_2, ["notify"])],
)
- yield self.replicate()
- yield self.check(
+ self.replicate()
+ self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1},
)
- yield self.persist(
+ self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
@@ -169,8 +165,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
],
)
- yield self.replicate()
- yield self.check(
+ self.replicate()
+ self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2},
@@ -178,7 +174,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
- @defer.inlineCallbacks
def persist(
self,
sender=USER_ID,
@@ -205,8 +200,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
depth = self.event_id
if not prev_events:
- latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
- room_id
+ latest_event_ids = self.get_success(
+ self.master_store.get_latest_event_ids_in_room(room_id)
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
@@ -239,19 +234,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
- context = yield state_handler.compute_event_context(event)
+ context = self.get_success(state_handler.compute_event_context(event))
- yield self.master_store.add_push_actions_to_staging(
+ self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
ordering = None
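+ # backfilled events do not get a stream ordering assigned here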
if backfill:
- yield self.master_store.persist_events([(event, context)], backfilled=True)
+ self.get_success(
+ self.master_store.persist_events([(event, context)], backfilled=True)
+ )
else:
- ordering, _ = yield self.master_store.persist_event(event, context)
+ ordering, _ = self.get_success(
+ self.master_store.persist_event(event, context)
+ )
if ordering:
event.internal_metadata.stream_ordering = ordering
- defer.returnValue(event)
+ return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index ae1adeded1..f47d94f690 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from ._base import BaseSlavedStoreTestCase
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
- @defer.inlineCallbacks
def test_receipt(self):
- yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
- yield self.master_store.insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
- )
- yield self.replicate()
- yield self.check(
- "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
+ self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+ self.get_success(
+ self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
)
+ self.replicate()
+ self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 9fe0760496..a824be9a62 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer
-import synapse.rest.client.v1.room
from synapse.api.constants import Membership
-from synapse.http.server import JsonResource
-from synapse.types import UserID
-from synapse.util import Clock
+from synapse.rest.client.v1 import admin, login, room
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock,
- make_request,
- render,
- setup_test_homeserver,
-)
-
-from .utils import RestHelper
PATH_PREFIX = b"/_matrix/client/api/v1"
-class RoomBase(unittest.TestCase):
+class RoomBase(unittest.HomeserverTestCase):
rmcreator_id = None
- def setUp(self):
+ servlets = [room.register_servlets, room.register_deprecated_servlets]
- self.clock = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.clock)
+ def make_homeserver(self, reactor, clock):
- self.hs = setup_test_homeserver(
- self.addCleanup,
+ self.hs = self.setup_test_homeserver(
"red",
http_client=None,
- clock=self.hs_clock,
- reactor=self.clock,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
@@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return synapse.types.create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
-
- self.hs.get_auth().get_user_by_req = get_user_by_req
- self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
- self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234")
-
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
self.hs.get_datastore().insert_client_ip = _insert_client_ip
- self.resource = JsonResource(self.hs)
- synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
- synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource)
- self.helper = RestHelper(self.hs, self.resource, self.user_id)
+ return self.hs
class RoomPermissionsTestCase(RoomBase):
""" Tests room permissions. """
- user_id = b"@sid1:red"
- rmcreator_id = b"@notme:red"
-
- def setUp(self):
+ user_id = "@sid1:red"
+ rmcreator_id = "@notme:red"
- super(RoomPermissionsTestCase, self).setUp()
+ def prepare(self, reactor, clock, hs):
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
@@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase):
self.created_rmid_msg_path = (
"rooms/%s/send/m.room.message/a1" % (self.created_rmid)
).encode('ascii')
- request, channel = make_request(
- b"PUT",
- self.created_rmid_msg_path,
- b'{"msgtype":"m.text","body":"test msg"}',
+ request, channel = self.make_request(
+ "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- render(request, self.resource, self.clock)
- self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
# set topic for public room
- request, channel = make_request(
- b"PUT",
+ request, channel = self.make_request(
+ "PUT",
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'),
b'{"topic":"Public Room Topic"}',
)
- render(request, self.resource, self.clock)
- self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
# auth as user_id now
self.helper.auth_user_id = self.user_id
@@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase):
seq = iter(range(100))
def send_msg_path():
- return b"/rooms/%s/send/m.room.message/mid%s" % (
+ return "/rooms/%s/send/m.room.message/mid%s" % (
self.created_rmid,
- str(next(seq)).encode('ascii'),
+ str(next(seq)),
)
# send message in uncreated room, expect 403
- request, channel = make_request(
- b"PUT",
- b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
- request, channel = make_request(b"PUT", send_msg_path(), msg_content)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = make_request(b"PUT", send_msg_path(), msg_content)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
- request, channel = make_request(b"PUT", send_msg_path(), msg_content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = make_request(b"PUT", send_msg_path(), msg_content)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
topic_content = b'{"topic":"My Topic Name"}'
- topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid
+ topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
# set/get topic in uncreated room, expect 403
- request, channel = make_request(
- b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
+ request, channel = self.make_request(
+ "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- request, channel = make_request(
- b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
- request, channel = make_request(b"PUT", topic_path, topic_content)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- request, channel = make_request(b"GET", topic_path)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", topic_path, topic_content)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
+ request, channel = self.make_request("GET", topic_path)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
- request, channel = make_request(b"PUT", topic_path, topic_content)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", topic_path, topic_content)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
- request, channel = make_request(b"GET", topic_path)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", topic_path)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
- request, channel = make_request(b"PUT", topic_path, topic_content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", topic_path, topic_content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
- request, channel = make_request(b"GET", topic_path)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assert_dict(json.loads(topic_content), channel.json_body)
+ request, channel = self.make_request("GET", topic_path)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
+ self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
- request, channel = make_request(b"PUT", topic_path, topic_content)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
- request, channel = make_request(b"GET", topic_path)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", topic_path, topic_content)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
+ request, channel = self.make_request("GET", topic_path)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
- request, channel = make_request(
- b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
- request, channel = make_request(
- b"PUT",
- b"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
def _test_get_membership(self, room=None, members=[], expect_code=None):
for member in members:
- path = b"/rooms/%s/state/m.room.member/%s" % (room, member)
- request, channel = make_request(b"GET", path)
- render(request, self.resource, self.clock)
- self.assertEquals(expect_code, int(channel.result["code"]))
+ path = "/rooms/%s/state/m.room.member/%s" % (room, member)
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(expect_code, channel.code)
def test_membership_basic_room_perms(self):
# === room does not exist ===
@@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase):
class RoomsMemberListTestCase(RoomBase):
""" Tests /rooms/$room_id/members/list REST events."""
- user_id = b"@sid1:red"
+ user_id = "@sid1:red"
def test_get_member_list(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
- request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members")
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
- room_id = self.helper.create_room_as(b"@some_other_guy:red")
- request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ room_id = self.helper.create_room_as("@some_other_guy:red")
+ request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
- room_creator = b"@some_other_guy:red"
+ room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator)
- room_path = b"/rooms/%s/members" % room_id
+ room_path = "/rooms/%s/members" % room_id
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
- request, channel = make_request(b"GET", room_path)
- render(request, self.resource, self.clock)
- self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", room_path)
+ self.render(request)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
- request, channel = make_request(b"GET", room_path)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", room_path)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
- request, channel = make_request(b"GET", room_path)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", room_path)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomsCreateTestCase(RoomBase):
""" Tests /rooms and /rooms/$room_id REST events. """
- user_id = b"@sid1:red"
+ user_id = "@sid1:red"
def test_post_room_no_keys(self):
# POST with no config keys, expect new room id
- request, channel = make_request(b"POST", b"/createRoom", b"{}")
+ request, channel = self.make_request("POST", "/createRoom", "{}")
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), channel.result)
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self):
# POST with visibility config key, expect new room id
- request, channel = make_request(
- b"POST", b"/createRoom", b'{"visibility":"private"}'
+ request, channel = self.make_request(
+ "POST", "/createRoom", b'{"visibility":"private"}'
)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]))
+ self.render(request)
+ self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self):
# POST with custom config keys, expect new room id
- request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}')
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]))
+ request, channel = self.make_request(
+ "POST", "/createRoom", b'{"custom":"stuff"}'
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self):
# POST with custom + known config keys, expect new room id
- request, channel = make_request(
- b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}'
+ request, channel = self.make_request(
+ "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]))
+ self.render(request)
+ self.assertEquals(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self):
# POST with invalid content / paths, expect 400
- request, channel = make_request(b"POST", b"/createRoom", b'{"visibili')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]))
+ request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
+ self.render(request)
+ self.assertEquals(400, channel.code)
- request, channel = make_request(b"POST", b"/createRoom", b'["hello"]')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]))
+ request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
+ self.render(request)
+ self.assertEquals(400, channel.code)
class RoomTopicTestCase(RoomBase):
""" Tests /rooms/$room_id/topic REST events. """
- user_id = b"@sid1:red"
-
- def setUp(self):
-
- super(RoomTopicTestCase, self).setUp()
+ user_id = "@sid1:red"
+ def prepare(self, reactor, clock, hs):
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
- self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,)
+ self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
def test_invalid_puts(self):
# missing keys or invalid json
- request, channel = make_request(b"PUT", self.path, '{}')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, '{}')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", self.path, '{"nao')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, '{"nao')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(
- b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = self.make_request(
+ "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", self.path, 'text only')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, 'text only')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", self.path, '')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, '')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid key, wrong type
content = '{"topic":["Topic name"]}'
- request, channel = make_request(b"PUT", self.path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, content)
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
# nothing should be there
- request, channel = make_request(b"GET", self.path)
- render(request, self.resource, self.clock)
- self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", self.path)
+ self.render(request)
+ self.assertEquals(404, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
- request, channel = make_request(b"PUT", self.path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = make_request(b"GET", self.path)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", self.path)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self):
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
- request, channel = make_request(b"PUT", self.path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", self.path, content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
# valid get
- request, channel = make_request(b"GET", self.path)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", self.path)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
class RoomMemberStateTestCase(RoomBase):
""" Tests /rooms/$room_id/members/$user_id/state REST events. """
- user_id = b"@sid1:red"
-
- def setUp(self):
+ user_id = "@sid1:red"
- super(RoomMemberStateTestCase, self).setUp()
+ def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
- def tearDown(self):
- pass
-
def test_invalid_puts(self):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
- request, channel = make_request(b"PUT", path, '{}')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, '{}')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, '{"nao')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, '{"nao')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(
- b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = self.make_request(
+ "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
)
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, 'text only')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, 'text only')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, '')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, '')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
# valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % (
@@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.JOIN,
Membership.LEAVE,
)
- request, channel = make_request(b"PUT", path, content.encode('ascii'))
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, content.encode('ascii'))
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
path = "/rooms/%s/state/m.room.member/%s" % (
@@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
- request, channel = make_request(b"PUT", path, content.encode('ascii'))
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, content.encode('ascii'))
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"GET", path, None)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", path, None)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEquals(expected_response, channel.json_body)
@@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
- request, channel = make_request(b"PUT", path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"GET", path, None)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", path, None)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self):
@@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase):
Membership.INVITE,
"Join us!",
)
- request, channel = make_request(b"PUT", path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"GET", path, None)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("GET", path, None)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
self.assertEquals(json.loads(content), channel.json_body)
@@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red"
- def setUp(self):
- super(RoomMessagesTestCase, self).setUp()
-
+ def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
- request, channel = make_request(b"PUT", path, '{}')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, b'{}')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, '{"_name":"bob"}')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, '{"nao')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, b'{"nao')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(
- b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]'
+ request, channel = self.make_request(
+ "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
)
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, 'text only')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, b'text only')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
- request, channel = make_request(b"PUT", path, '')
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ request, channel = self.make_request("PUT", path, b'')
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
- content = '{"body":"test","msgtype":{"type":"a"}}'
- request, channel = make_request(b"PUT", path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"])
+ content = b'{"body":"test","msgtype":{"type":"a"}}'
+ request, channel = self.make_request("PUT", path, content)
+ self.render(request)
+ self.assertEquals(400, channel.code, msg=channel.result["body"])
# custom message types
- content = '{"body":"test","msgtype":"test.custom.text"}'
- request, channel = make_request(b"PUT", path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ content = b'{"body":"test","msgtype":"test.custom.text"}'
+ request, channel = self.make_request("PUT", path, content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
- content = '{"body":"test2","msgtype":"m.text"}'
- request, channel = make_request(b"PUT", path, content)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ content = b'{"body":"test2","msgtype":"m.text"}'
+ request, channel = self.make_request("PUT", path, content)
+ self.render(request)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
class RoomInitialSyncTestCase(RoomBase):
@@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red"
- def setUp(self):
- super(RoomInitialSyncTestCase, self).setUp()
-
+ def prepare(self, reactor, clock, hs):
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self):
- request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]))
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/initialSync" % self.room_id
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
self.assertEquals(self.room_id, channel.json_body["room_id"])
self.assertEquals("join", channel.json_body["membership"])
@@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red"
- def setUp(self):
- super(RoomMessageListTestCase, self).setUp()
+ def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0_0"
- request, channel = make_request(
- b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]))
+ self.render(request)
+ self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body)
@@ -837,12 +790,116 @@ class RoomMessageListTestCase(RoomBase):
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0_0"
- request, channel = make_request(
- b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+ request, channel = self.make_request(
+ "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- render(request, self.resource, self.clock)
- self.assertEquals(200, int(channel.result["code"]))
+ self.render(request)
+ self.assertEquals(200, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
+
+
+class RoomSearchTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def prepare(self, reactor, clock, hs):
+
+ # Register the user who does the searching
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ # Invite the other person
+ self.helper.invite(
+ room=self.room,
+ src=self.user_id,
+ tok=self.access_token,
+ targ=self.other_user_id,
+ )
+
+ # The other user joins
+ self.helper.join(
+ room=self.room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ def test_finds_message(self):
+ """
+ The search functionality will search for content in messages if asked to
+ do so.
+ """
+ # The other user sends some messages
+ self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
+ self.helper.send(self.room, body="There!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "POST",
+ "/search?access_token=%s" % (self.access_token,),
+ {
+ "search_categories": {
+ "room_events": {"keys": ["content.body"], "search_term": "Hi"}
+ }
+ },
+ )
+ self.render(request)
+
+ # Check we get the results we expect -- one search result, matching one of
+ # the sent messages
+ self.assertEqual(channel.code, 200)
+ results = channel.json_body["search_categories"]["room_events"]
+ self.assertEqual(results["count"], 1)
+ self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
+
+ # No context was requested, so we should get none.
+ self.assertEqual(results["results"][0]["context"], {})
+
+ def test_include_context(self):
+ """
+ When event_context includes include_profile, profile information will be
+ included in the search response.
+ """
+ # The other user sends some messages
+ self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
+ self.helper.send(self.room, body="There!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "POST",
+ "/search?access_token=%s" % (self.access_token,),
+ {
+ "search_categories": {
+ "room_events": {
+ "keys": ["content.body"],
+ "search_term": "Hi",
+ "event_context": {"include_profile": True},
+ }
+ }
+ },
+ )
+ self.render(request)
+
+ # Check we get the results we expect -- one search result, matching one of
+ # the sent messages
+ self.assertEqual(channel.code, 200)
+ results = channel.json_body["search_categories"]["room_events"]
+ self.assertEqual(results["count"], 1)
+ self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
+
+ # We should get context info, like the two users, and the display names.
+ context = results["results"][0]["context"]
+ self.assertEqual(len(context["profile_info"].keys()), 2)
+ self.assertEqual(
+ context["profile_info"][self.other_user_id]["displayname"], "otheruser"
+ )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 40dc4ea256..530dc8ba6d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -240,7 +240,6 @@ class RestHelper(object):
self.assertEquals(200, code)
defer.returnValue(response)
- @defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -248,9 +247,16 @@ class RestHelper(object):
body = "body_text_here"
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
- content = '{"msgtype":"m.text","body":"%s"}' % body
+ content = {"msgtype": "m.text", "body": body}
if tok:
path = path + "?access_token=%s" % tok
- (code, response) = yield self.mock_resource.trigger("PUT", path, content)
- self.assertEquals(expect_code, code, msg=str(response))
+ request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 560b1fba96..4c30c5f258 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertTrue(
set(
- [
- "next_batch",
- "rooms",
- "account_data",
- "to_device",
- "device_lists",
- ]
+ ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
).issubset(set(channel.json_body.keys()))
)
diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/scripts/__init__.py
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
new file mode 100644
index 0000000000..6f56893f5e
--- /dev/null
+++ b/tests/scripts/test_new_matrix_user.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from synapse._scripts.register_new_matrix_user import request_registration
+
+from tests.unittest import TestCase
+
+
+class RegisterTestCase(TestCase):
+ def test_success(self):
+ """
+ The script will fetch a nonce, and then generate a MAC with it, and then
+ post that MAC.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 200
+ r.json = lambda: {"nonce": "a"}
+ return r
+
+ def post(url, json=None, verify=None):
+ # Make sure we are sent the correct info
+ self.assertEqual(json["username"], "user")
+ self.assertEqual(json["password"], "pass")
+ self.assertEqual(json["nonce"], "a")
+ # We want a 40-char hex MAC
+ self.assertEqual(len(json["mac"]), 40)
+
+ r = Mock()
+ r.status_code = 200
+ return r
+
+ requests = Mock()
+ requests.get = get
+ requests.post = post
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # We should get the success message making sure everything is OK.
+ self.assertIn("Success!", out)
+
+ # sys.exit shouldn't have been called.
+ self.assertEqual(err_code, [])
+
+ def test_failure_nonce(self):
+ """
+ If the script fails to fetch a nonce, it throws an error and quits.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 404
+ r.reason = "Not Found"
+ r.json = lambda: {"not": "error"}
+ return r
+
+ requests = Mock()
+ requests.get = get
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # Exit was called
+ self.assertEqual(err_code, [1])
+
+ # We got an error message
+ self.assertIn("ERROR! Received 404 Not Found", out)
+ self.assertNotIn("Success!", out)
+
+ def test_failure_post(self):
+ """
+ The script will fetch a nonce, and then if the final POST fails, will
+ report an error and quit.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 200
+ r.json = lambda: {"nonce": "a"}
+ return r
+
+ def post(url, json=None, verify=None):
+ # Make sure we are sent the correct info
+ self.assertEqual(json["username"], "user")
+ self.assertEqual(json["password"], "pass")
+ self.assertEqual(json["nonce"], "a")
+ # We want a 40-char hex MAC
+ self.assertEqual(len(json["mac"]), 40)
+
+ r = Mock()
+ # Then 500 because we're jerks
+ r.status_code = 500
+ r.reason = "Broken"
+ return r
+
+ requests = Mock()
+ requests.get = get
+ requests.post = post
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # Exit was called
+ self.assertEqual(err_code, [1])
+
+ # We got an error message
+ self.assertIn("ERROR! Received 500 Broken", out)
+ self.assertNotIn("Success!", out)
diff --git a/tests/server.py b/tests/server.py
index c63b2c3100..7bee58dff1 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -4,9 +4,14 @@ from io import BytesIO
from six import text_type
import attr
+from zope.interface import implementer
-from twisted.internet import threads
+from twisted.internet import address, threads, udp
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address
from twisted.internet.defer import Deferred
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
@@ -63,7 +68,9 @@ class FakeChannel(object):
self.result["done"] = True
def getPeer(self):
- return None
+ # We give an address so that getClientIP returns a non-null entry,
+ # causing the request to be counted towards the MAU figures
+ return address.IPv4Address("TCP", "127.0.0.1", 3423)
def getHost(self):
return None
@@ -91,7 +98,7 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b""):
+def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
@@ -113,9 +120,16 @@ def make_request(method, path, content=b""):
site = FakeSite()
channel = FakeChannel()
- req = SynapseRequest(site, channel)
+ req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
+
+ if access_token:
+ req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+
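+ # anything with a body is assumed to be JSON for the purposes of these tests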
+ if content:
+ req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+
req.requestReceived(method, path, b"1.1")
return req, channel
@@ -147,11 +161,46 @@ def render(request, resource, clock):
wait_until_result(clock, request)
+@implementer(IReactorPluggableNameResolver)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
"""
+ def __init__(self):
+ self._udp = []
+ self.lookups = {}
+
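+ # A minimal fake resolver which answers from self.lookups, and raises
+ # DNSLookupError for any host it does not know about.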
+ class Resolver(object):
+ def resolveHostName(
+ _self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics='TCP',
+ ):
+
+ resolution = HostResolution(hostName)
+ resolutionReceiver.resolutionBegan(resolution)
+ if hostName not in self.lookups:
+ raise DNSLookupError("OH NO")
+
+ resolutionReceiver.addressResolved(
+ IPv4Address('TCP', self.lookups[hostName], portNumber)
+ )
+ resolutionReceiver.resolutionComplete()
+ return resolution
+
+ self.nameResolver = Resolver()
+ super(ThreadedMemoryReactorClock, self).__init__()
+
+ def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
+ p = udp.Port(port, protocol, interface, maxPacketSize, self)
+ p.startListening()
+ self._udp.append(p)
+ return p
+
def callFromThread(self, callback, *args, **kwargs):
"""
Make the callback fire in the next reactor iteration.
@@ -225,6 +274,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
+ pool.running = True
return d
@@ -232,3 +282,84 @@ def get_clock():
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
return (clock, hs_clock)
+
+
+@attr.s
+class FakeTransport(object):
+ """
+ A twisted.internet.interfaces.ITransport implementation which sends all its data
+ straight into an IProtocol object: it exists to connect two IProtocols together.
+
+ To use it, instantiate it with the receiving IProtocol, and then pass it to the
+ sending IProtocol's makeConnection method:
+
+ server = HTTPChannel()
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ If you want bidirectional communication, you'll need two instances.
+ """
+
+ other = attr.ib()
+ """The Protocol object which will receive any data written to this transport.
+
+ :type: twisted.internet.interfaces.IProtocol
+ """
+
+ _reactor = attr.ib()
+ """Test reactor
+
+ :type: twisted.internet.interfaces.IReactorTime
+ """
+
+ disconnecting = False
+ buffer = attr.ib(default=b'')
+ producer = attr.ib(default=None)
+
+ def getPeer(self):
+ return None
+
+ def getHost(self):
+ return None
+
+ def loseConnection(self):
+ self.disconnecting = True
+
+ def abortConnection(self):
+ self.disconnecting = True
+
+ def pauseProducing(self):
+ self.producer.pauseProducing()
+
+ def unregisterProducer(self):
+ if not self.producer:
+ return
+
+ self.producer = None
+
+ def registerProducer(self, producer, streaming):
+ self.producer = producer
+ self.producerStreaming = streaming
+
+ def _produce():
+ d = self.producer.resumeProducing()
+ d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
+
+ if not streaming:
+ self._reactor.callLater(0.0, _produce)
+
+ def write(self, byt):
+ self.buffer = self.buffer + byt
+
+ def _write():
+ if getattr(self.other, "transport") is not None:
+ self.other.dataReceived(self.buffer)
+ self.buffer = b""
+ return
+
+ self._reactor.callLater(0.0, _write)
+
+ _write()
+
+ def writeSequence(self, seq):
+ for x in seq:
+ self.write(x)
diff --git a/tests/server_notices/__init__.py b/tests/server_notices/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/server_notices/__init__.py
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
new file mode 100644
index 0000000000..95badc985e
--- /dev/null
+++ b/tests/server_notices/test_consent.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1 import admin, login, room
+from synapse.rest.client.v2_alpha import sync
+
+from tests import unittest
+
+
+class ConsentNoticesTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ sync.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.consent_notice_message = "consent %(consent_uri)s"
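+ # %(consent_uri)s is interpolated with the user's consent URI when the
+ # notice is sent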
+ config = self.default_config()
+ config.user_consent_version = "1"
+ config.user_consent_server_notice_content = {
+ "msgtype": "m.text",
+ "body": self.consent_notice_message,
+ }
+ config.public_baseurl = "https://example.com/"
+ config.form_secret = "123abc"
+
+ config.server_notices_mxid = "@notices:test"
+ config.server_notices_mxid_display_name = "test display name"
+ config.server_notices_mxid_avatar_url = None
+ config.server_notices_room_name = "Server Notices"
+
+ hs = self.setup_test_homeserver(config=config)
+
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("bob", "abc123")
+ self.access_token = self.login("bob", "abc123")
+
+ def test_get_sync_message(self):
+ """
+ When user consent server notices are enabled, a sync will cause a notice
+ to fire (in a room which the user is invited to). The notice contains
+ the notice URL + an authentication code.
+ """
+ # Initial sync, to get the user consent room invite
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/sync", access_token=self.access_token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # Get the Room ID to join
+ room_id = list(channel.json_body["rooms"]["invite"].keys())[0]
+
+ # Join the room
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/" + room_id + "/join",
+ access_token=self.access_token,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # Sync again, to get the message in the room
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/sync", access_token=self.access_token
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ # Get the message
+ room = channel.json_body["rooms"]["join"][room_id]
+ messages = [
+ x for x in room["timeline"]["events"] if x["type"] == "m.room.message"
+ ]
+
+ # One message, with the consent URL
+ self.assertEqual(len(messages), 1)
+ self.assertTrue(
+ messages[0]["content"]["body"].startswith(
+ "consent https://example.com/_matrix/consent"
+ )
+ )
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
new file mode 100644
index 0000000000..4701eedd45
--- /dev/null
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -0,0 +1,207 @@
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, ServerNoticeMsgType
+from synapse.api.errors import ResourceLimitError
+from synapse.handlers.auth import AuthHandler
+from synapse.server_notices.resource_limits_server_notices import (
+ ResourceLimitsServerNotices,
+)
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+
+class AuthHandlers(object):
+ def __init__(self, hs):
+ self.auth_handler = AuthHandler(hs)
+
+
+class TestResourceLimitsServerNotices(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
+ self.hs.handlers = AuthHandlers(self.hs)
+ self.auth_handler = self.hs.handlers.auth_handler
+ self.server_notices_sender = self.hs.get_server_notices_sender()
+
+        # Relying on the list index [1] is far from ideal, but this test is the
+        # only case where the ResourceLimitsServerNotices class needs to be
+        # isolated; general code should never have a reason to do so.
+ self._rlsn = self.server_notices_sender._server_notices[1]
+ if not isinstance(self._rlsn, ResourceLimitsServerNotices):
+ raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+
+ self._rlsn._store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(1000)
+ )
+        self._rlsn._server_notices_manager.send_notice = Mock()
+        self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
+        self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
+
+        # Keep a reference to the send_notice mock so the tests can assert on it
+        self._send_notice = self._rlsn._server_notices_manager.send_notice
+
+ self.hs.config.limit_usage_by_mau = True
+ self.user_id = "@user_id:test"
+
+ # self.server_notices_mxid = "@server:test"
+ # self.server_notices_mxid_display_name = None
+ # self.server_notices_mxid_avatar_url = None
+ # self.server_notices_room_name = "Server Notices"
+
+        self._rlsn._server_notices_manager.get_notice_room_for_user = Mock(
+            return_value=""
+        )
+ self._rlsn._store.add_tag_to_room = Mock()
+ self._rlsn._store.get_tags_for_room = Mock(return_value={})
+ self.hs.config.admin_contact = "mailto:user@test.com"
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_flag_off(self):
+ """Tests cases where the flags indicate nothing to do"""
+ # test hs disabled case
+ self.hs.config.hs_disabled = True
+
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+ # Test when mau limiting disabled
+ self.hs.config.hs_disabled = False
+        self.hs.config.limit_usage_by_mau = False
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+ """Test when user has blocked notice, but should have it removed"""
+
+ self._rlsn._auth.check_auth_blocking = Mock()
+ mock_event = Mock(
+ type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
+ )
+ self._rlsn._store.get_events = Mock(
+ return_value=defer.succeed({"123": mock_event})
+ )
+
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+        # It would be better to check the content, but a single call here means
+        # the blocking notice was removed
+ self._send_notice.assert_called_once()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+ """Test when user has blocked notice, but notice ought to be there (NOOP)"""
+ self._rlsn._auth.check_auth_blocking = Mock(
+ side_effect=ResourceLimitError(403, 'foo')
+ )
+
+ mock_event = Mock(
+ type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
+ )
+ self._rlsn._store.get_events = Mock(
+ return_value=defer.succeed({"123": mock_event})
+ )
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+ """Test when user does not have blocked notice, but should have one"""
+
+ self._rlsn._auth.check_auth_blocking = Mock(
+ side_effect=ResourceLimitError(403, 'foo')
+ )
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+        # It would be better to check the contents, but two calls here mean the
+        # blocking notice was set
+        self.assertEqual(self._send_notice.call_count, 2)
+
+ @defer.inlineCallbacks
+ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+ """Test when user does not have blocked notice, nor should they (NOOP)"""
+
+ self._rlsn._auth.check_auth_blocking = Mock()
+
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+ @defer.inlineCallbacks
+    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+        """Test when the user is not part of the MAU cohort - this should never
+        happen, but the code should cope gracefully if it does.
+        """
+
+ self._rlsn._auth.check_auth_blocking = Mock()
+ self._rlsn._store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(None)
+ )
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+ self._send_notice.assert_not_called()
+
+
+class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver(self.addCleanup)
+ self.store = self.hs.get_datastore()
+ self.server_notices_sender = self.hs.get_server_notices_sender()
+ self.server_notices_manager = self.hs.get_server_notices_manager()
+ self.event_source = self.hs.get_event_sources()
+
+        # Relying on the list index [1] is far from ideal, but this test is the
+        # only case where the ResourceLimitsServerNotices class needs to be
+        # isolated; general code should never have a reason to do so.
+ self._rlsn = self.server_notices_sender._server_notices[1]
+ if not isinstance(self._rlsn, ResourceLimitsServerNotices):
+ raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.hs_disabled = False
+ self.hs.config.max_mau_value = 5
+ self.hs.config.server_notices_mxid = "@server:test"
+ self.hs.config.server_notices_mxid_display_name = None
+ self.hs.config.server_notices_mxid_avatar_url = None
+ self.hs.config.server_notices_room_name = "Test Server Notice Room"
+
+ self.user_id = "@user_id:test"
+
+ self.hs.config.admin_contact = "mailto:user@test.com"
+
+ @defer.inlineCallbacks
+ def test_server_notice_only_sent_once(self):
+ self.store.get_monthly_active_count = Mock(return_value=1000)
+
+ self.store.user_last_seen_monthly_active = Mock(return_value=1000)
+
+ # Call the function multiple times to ensure we only send the notice once
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+ yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
+
+        # Now let's fetch the most recent messages in the server notices room
+        # and check that there is only one server notice
+ room_id = yield self.server_notices_manager.get_notice_room_for_user(
+ self.user_id
+ )
+
+ token = yield self.event_source.get_current_token()
+ events, _ = yield self.store.get_recent_events_for_room(
+ room_id, limit=100, end_token=token.room_key
+ )
+
+ count = 0
+ for event in events:
+ if event.type != EventTypes.Message:
+ continue
+ if event.content.get("msgtype") != ServerNoticeMsgType:
+ continue
+
+ count += 1
+
+ self.assertEqual(count, 1)
diff --git a/tests/state/__init__.py b/tests/state/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/state/__init__.py
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
new file mode 100644
index 0000000000..efd85ebe6c
--- /dev/null
+++ b/tests/state/test_v2.py
@@ -0,0 +1,663 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+from six.moves import zip
+
+import attr
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.event_auth import auth_types_for_event
+from synapse.events import FrozenEvent
+from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.types import EventID
+
+from tests import unittest
+
+ALICE = "@alice:example.com"
+BOB = "@bob:example.com"
+CHARLIE = "@charlie:example.com"
+EVELYN = "@evelyn:example.com"
+ZARA = "@zara:example.com"
+
+ROOM_ID = "!test:example.com"
+
+MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
+MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
+
+
+ORIGIN_SERVER_TS = 0
+
+
+class FakeEvent(object):
+ """A fake event we use as a convenience.
+
+ NOTE: Again as a convenience we use "node_ids" rather than event_ids to
+ refer to events. The event_id has node_id as localpart and example.com
+ as domain.
+ """
+ def __init__(self, id, sender, type, state_key, content):
+ self.node_id = id
+ self.event_id = EventID(id, "example.com").to_string()
+ self.sender = sender
+ self.type = type
+ self.state_key = state_key
+ self.content = content
+
+ def to_event(self, auth_events, prev_events):
+ """Given the auth_events and prev_events, convert to a Frozen Event
+
+ Args:
+ auth_events (list[str]): list of event_ids
+ prev_events (list[str]): list of event_ids
+
+ Returns:
+ FrozenEvent
+ """
+ global ORIGIN_SERVER_TS
+
+ ts = ORIGIN_SERVER_TS
+ ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1
+
+ event_dict = {
+ "auth_events": [(a, {}) for a in auth_events],
+ "prev_events": [(p, {}) for p in prev_events],
+ "event_id": self.node_id,
+ "sender": self.sender,
+ "type": self.type,
+ "content": self.content,
+ "origin_server_ts": ts,
+ "room_id": ROOM_ID,
+ }
+
+ if self.state_key is not None:
+ event_dict["state_key"] = self.state_key
+
+ return FrozenEvent(event_dict)
+
+
+# All graphs start with this set of events
+INITIAL_EVENTS = [
+ FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ),
+ FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IPOWER",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100}},
+ ),
+ FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ),
+ FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMZ",
+ sender=ZARA,
+ type=EventTypes.Member,
+ state_key=ZARA,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="START",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="END",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+]
+
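+# The initial chain of prev_events: each entry in INITIAL_EDGES has the entry
+# after it as its prev_event, so chronologically the chain runs
+# CREATE -> IMA -> IPOWER -> IJR -> IMB -> IMC -> IMZ -> START.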
+INITIAL_EDGES = [
+ "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
+]
+
+
+class StateTestCase(unittest.TestCase):
+ def test_ban_vs_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="MA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content={"membership": Membership.JOIN},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "MA", "PA", "START"],
+ ["END", "PB", "PA"],
+ ]
+
+ expected_state_ids = ["PA", "MA", "MB"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_join_rule_evasion(self):
+ events = [
+ FakeEvent(
+ id="JR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rules": JoinRules.PRIVATE},
+ ),
+ FakeEvent(
+ id="ME",
+ sender=EVELYN,
+ type=EventTypes.Member,
+ state_key=EVELYN,
+ content={"membership": Membership.JOIN},
+ ),
+ ]
+
+ edges = [
+ ["END", "JR", "START"],
+ ["END", "ME", "START"],
+ ]
+
+ expected_state_ids = ["JR"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_offtopic_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PC",
+ sender=CHARLIE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 0,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "PC", "PB", "PA", "START"],
+ ["END", "PA"],
+ ]
+
+ expected_state_ids = ["PC"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_basic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["PA2", "T2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_reset(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "T2", "PA", "T1", "START"],
+ ["END", "T1"],
+ ]
+
+ expected_state_ids = ["T1", "MB", "PA"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MZ1",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="T4",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "MZ1", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["T4", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def do_check(self, events, edges, expected_state_ids):
+ """Take a list of events and edges and calculate the state of the
+ graph at END, and asserts it matches `expected_state_ids`
+
+ Args:
+ events (list[FakeEvent])
+ edges (list[list[str]]): A list of chains of event edges, e.g.
+ `[[A, B, C]]` are edges A->B and B->C.
+ expected_state_ids (list[str]): The expected state at END, (excluding
+ the keys that haven't changed since START).
+ """
+ # We want to sort the events into topological order for processing.
+ graph = {}
+
+ # node_id -> FakeEvent
+ fake_event_map = {}
+
+ for ev in itertools.chain(INITIAL_EVENTS, events):
+ graph[ev.node_id] = set()
+ fake_event_map[ev.node_id] = ev
+
+ for a, b in pairwise(INITIAL_EDGES):
+ graph[a].add(b)
+
+ for edge_list in edges:
+ for a, b in pairwise(edge_list):
+ graph[a].add(b)
+
+ # event_id -> FrozenEvent
+ event_map = {}
+ # node_id -> state
+ state_at_event = {}
+
+ # We copy the map as the sort consumes the graph
+ graph_copy = {k: set(v) for k, v in graph.items()}
+
+ for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
+ fake_event = fake_event_map[node_id]
+ event_id = fake_event.event_id
+
+ prev_events = list(graph[node_id])
+
+ if len(prev_events) == 0:
+ state_before = {}
+ elif len(prev_events) == 1:
+ state_before = dict(state_at_event[prev_events[0]])
+ else:
+ state_d = resolve_events_with_store(
+ [state_at_event[n] for n in prev_events],
+ event_map=event_map,
+ state_res_store=TestStateResolutionStore(event_map),
+ )
+
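+                # The test store is synchronous, so the deferred should already
+                # have completed and we can read the result directly.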
+ self.assertTrue(state_d.called)
+ state_before = state_d.result
+
+ state_after = dict(state_before)
+ if fake_event.state_key is not None:
+ state_after[(fake_event.type, fake_event.state_key)] = event_id
+
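+            # Pick this event's auth events out of the state before it.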
+ auth_types = set(auth_types_for_event(fake_event))
+
+ auth_events = []
+ for key in auth_types:
+ if key in state_before:
+ auth_events.append(state_before[key])
+
+ event = fake_event.to_event(auth_events, prev_events)
+
+ state_at_event[node_id] = state_after
+ event_map[event_id] = event
+
+ expected_state = {}
+ for node_id in expected_state_ids:
+ # expected_state_ids are node IDs rather than event IDs,
+ # so we have to convert
+ event_id = EventID(node_id, "example.com").to_string()
+ event = event_map[event_id]
+
+ key = (event.type, event.state_key)
+
+ expected_state[key] = event_id
+
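+        # Compare against the state at END, ignoring entries that are unchanged
+        # since START (unless they are explicitly expected).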
+ start_state = state_at_event["START"]
+ end_state = {
+ key: value
+ for key, value in state_at_event["END"].items()
+ if key in expected_state or start_state.get(key) != value
+ }
+
+ self.assertEqual(expected_state, end_state)
+
+
+class LexicographicalTestCase(unittest.TestCase):
+ def test_simple(self):
+ graph = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
+
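+        # "o" has no dependencies so it is emitted first; the rest are emitted
+        # as their dependencies are satisfied, breaking ties lexicographically.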
+ res = list(lexicographical_topological_sort(graph, key=lambda x: x))
+
+ self.assertEqual(["o", "l", "n", "m", "p"], res)
+
+
+def pairwise(iterable):
+ "s -> (s0,s1), (s1,s2), (s2, s3), ..."
+ a, b = itertools.tee(iterable)
+ next(b, None)
+ return zip(a, b)
+
+
+@attr.s
+class TestStateResolutionStore(object):
+ event_map = attr.ib()
+
+ def get_events(self, event_ids, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ """
+
+ return {
+ eid: self.event_map[eid]
+ for eid in event_ids
+ if eid in self.event_map
+ }
+
+ def get_auth_chain(self, event_ids):
+ """Gets the full auth chain for a set of events (including rejected
+ events).
+
+ Includes the given event IDs in the result.
+
+ Note that:
+ 1. All events must be state events.
+ 2. For v1 rooms this may not have the full auth chain in the
+ presence of rejected events
+
+ Args:
+ event_ids (list): The event IDs of the events to fetch the auth
+ chain for. Must be state events.
+
+ Returns:
+ Deferred[list[str]]: List of event IDs of the auth chain.
+ """
+
+ # Simple DFS for auth chain
+ result = set()
+ stack = list(event_ids)
+ while stack:
+ event_id = stack.pop()
+ if event_id in result:
+ continue
+
+ result.add(event_id)
+
+ event = self.event_map[event_id]
+ for aid, _ in event.auth_events:
+ stack.append(aid)
+
+ return list(result)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index c893990454..3f0083831b 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
- config = Mock(
- app_service_config_files=self.as_yaml_files,
- event_cache_size=1,
- password_providers=[],
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
self.as_token = "token1"
self.as_url = "some_url"
self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- self.store = ApplicationServiceStore(None, hs)
+ self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def setUp(self):
self.as_yaml_files = []
- config = Mock(
- app_service_config_files=self.as_yaml_files,
- event_cache_size=1,
- password_providers=[],
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+
+ hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
self.db_pool = hs.get_db_pool()
+ self.engine = hs.database_engine
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- self.store = TestTransactionStore(None, hs)
+ self.store = TestTransactionStore(hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_state(as_id, state, last_txn) "
- "VALUES(?,?,?)",
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_state(as_id, state, last_txn) "
+ "VALUES(?,?,?)"
+ ),
(id, state, txn),
)
def _insert_txn(self, as_id, txn_id, events):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
- "VALUES(?,?,?)",
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+ "VALUES(?,?,?)"
+ ),
(as_id, txn_id, json.dumps([e.event_id for e in events])),
)
def _set_last_txn(self, as_id, txn_id):
- return self.db_pool.runQuery(
- "INSERT INTO application_services_state(as_id, last_txn, state) "
- "VALUES(?,?,?)",
+ return self.db_pool.runOperation(
+ self.engine.convert_param_style(
+ "INSERT INTO application_services_state(as_id, last_txn, state) "
+ "VALUES(?,?,?)"
+ ),
(as_id, txn_id, ApplicationServiceState.UP),
)
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
- service = Mock(id=999)
+ service = Mock(id="999")
state = yield self.store.get_appservice_state(service)
self.assertEquals(None, state)
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
rows = yield self.db_pool.runQuery(
- "SELECT as_id FROM application_services_state WHERE state=?",
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
(ApplicationServiceState.DOWN,),
)
self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
rows = yield self.db_pool.runQuery(
- "SELECT as_id FROM application_services_state WHERE state=?",
+ self.engine.convert_param_style(
+ "SELECT as_id FROM application_services_state WHERE state=?"
+ ),
(ApplicationServiceState.UP,),
)
self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
+ self.engine.convert_param_style(
+ "SELECT last_txn FROM application_services_state WHERE as_id=?"
+ ),
(service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery(
- "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
- "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
+ self.engine.convert_param_style(
+ "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+ ),
(service.id,),
)
self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery(
- "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+ self.engine.convert_param_style(
+ "SELECT * FROM application_services_txns WHERE txn_id=?"
+ ),
+ (txn_id,),
)
self.assertEquals(0, len(res))
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- ApplicationServiceStore(None, hs)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
+ ApplicationServiceStore(hs.get_db_conn(), hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(None, hs)
+ ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
- config = Mock(
- app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
- )
hs = yield setup_test_homeserver(
- self.addCleanup,
- config=config,
- datastore=Mock(),
- federation_sender=Mock(),
- federation_client=Mock(),
+ self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
+ hs.config.app_service_config_files = [f1, f2]
+ hs.config.event_cache_size = 1
+ hs.config.password_providers = []
+
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(None, hs)
+ ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 7cb5f0e4cf..829f47d2e8 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -20,11 +20,11 @@ from mock import Mock
from twisted.internet import defer
-from synapse.server import HomeServer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine
from tests import unittest
+from tests.utils import TestHomeServer
class SQLBaseStoreTestCase(unittest.TestCase):
@@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
- hs = HomeServer(
+ hs = TestHomeServer(
"test",
db_pool=self.db_pool,
config=config,
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index c2e88bdbaf..4577e9422b 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,35 +13,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from mock import Mock
from twisted.internet import defer
-import tests.unittest
-import tests.utils
+from synapse.http.site import XForwardedForRequest
+from synapse.rest.client.v1 import admin, login
+
+from tests import unittest
-class ClientIpStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
- self.clock = None # type: tests.utils.MockClock
+class ClientIpStoreTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+    def prepare(self, reactor, clock, hs):
self.store = self.hs.get_datastore()
- self.clock = self.hs.get_clock()
- @defer.inlineCallbacks
def test_insert_new_client_ip(self):
- self.clock.now = 12345678
+ self.reactor.advance(12345678)
+
user_id = "@user:id"
- yield self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
)
- result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
+ # Trigger the storage loop
+ self.reactor.advance(10)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(user_id, "device_id")
+ )
r = result[(user_id, "device_id")]
self.assertDictContainsSubset(
@@ -55,18 +62,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
r,
)
- @defer.inlineCallbacks
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
user_id = "@user:server"
- yield self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
)
- active = yield self.store.user_last_seen_monthly_active(user_id)
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
- @defer.inlineCallbacks
def test_adding_monthly_active_user_when_full(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
@@ -76,40 +83,119 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
)
- yield self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
)
- active = yield self.store.user_last_seen_monthly_active(user_id)
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
- @defer.inlineCallbacks
def test_adding_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
- active = yield self.store.user_last_seen_monthly_active(user_id)
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
- yield self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ # Trigger the saving loop
+ self.reactor.advance(10)
+
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
)
- active = yield self.store.user_last_seen_monthly_active(user_id)
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
- @defer.inlineCallbacks
def test_updating_monthly_active_user_when_space(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 50
user_id = "@user:server"
+ self.get_success(
+ self.store.register(user_id=user_id, token="123", password_hash=None)
+ )
- active = yield self.store.user_last_seen_monthly_active(user_id)
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
- yield self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
- )
- yield self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ # Trigger the saving loop
+ self.reactor.advance(10)
+
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", "device_id"
+ )
)
- active = yield self.store.user_last_seen_monthly_active(user_id)
+ active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
+
+
+class ClientIpAuthTestCase(unittest.HomeserverTestCase):
+
+ servlets = [admin.register_servlets, login.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ return hs
+
+    def prepare(self, reactor, clock, hs):
+ self.store = self.hs.get_datastore()
+ self.user_id = self.register_user("bob", "abc123", True)
+
+ def test_request_with_xforwarded(self):
+ """
+ The IP in X-Forwarded-For is entered into the client IPs table.
+ """
+ self._runtest(
+ {b"X-Forwarded-For": b"127.9.0.1"},
+ "127.9.0.1",
+ {"request": XForwardedForRequest},
+ )
+
+ def test_request_from_getPeer(self):
+ """
+ The IP returned by getPeer is entered into the client IPs table, if
+ there's no X-Forwarded-For header.
+ """
+ self._runtest({}, "127.0.0.1", {})
+
+ def _runtest(self, headers, expected_ip, make_request_args):
+ device_id = "bleb"
+
+ access_token = self.login("bob", "abc123", device_id=device_id)
+
+ # Advance to a known time
+ self.reactor.advance(123456 - self.reactor.seconds())
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/admin/users/" + self.user_id,
+ access_token=access_token,
+ **make_request_args
+ )
+ request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
+
+ # Add the optional headers
+ for h, v in headers.items():
+ request.requestHeaders.addRawHeader(h, v)
+ self.render(request)
+
+ # Advance so the save loop occurs
+ self.reactor.advance(100)
+
+ result = self.get_success(
+ self.store.get_last_client_ip_by_device(self.user_id, device_id)
+ )
+ r = result[(self.user_id, device_id)]
+ self.assertDictContainsSubset(
+ {
+ "user_id": self.user_id,
+ "device_id": device_id,
+ "ip": expected_ip,
+ "user_agent": "Mozzila pizza",
+ "last_seen": 123456100,
+ },
+ r,
+ )
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index b4510c1c8d..4e128e1047 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -16,7 +16,6 @@
from twisted.internet import defer
-from synapse.storage.directory import DirectoryStore
from synapse.types import RoomAlias, RoomID
from tests import unittest
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = DirectoryStore(None, hs)
+ self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2fdf34fdf6..0d4e74d637 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
(
"INSERT INTO events ("
" room_id, event_id, type, depth, topological_ordering,"
- " content, processed, outlier) "
- "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+ " content, processed, outlier, stream_ordering) "
+ "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
),
- (room_id, event_id, i, i, True, False),
+ (room_id, event_id, i, i, True, False, i),
)
txn.execute(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index f2ed866ae7..832e379a83 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,26 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from mock import Mock
from twisted.internet import defer
-import tests.unittest
-import tests.utils
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
FORTY_DAYS = 40 * 24 * 60 * 60
-class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
+class MonthlyActiveUsersTestCase(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+ hs = self.setup_test_homeserver()
+ self.store = hs.get_datastore()
+ hs.config.limit_usage_by_mau = True
+ hs.config.max_mau_value = 50
+ # Advance the clock a bit
+ reactor.advance(FORTY_DAYS)
+
+ return hs
- @defer.inlineCallbacks
def test_initialise_reserved_users(self):
self.hs.config.max_mau_value = 5
user1 = "@user1:server"
@@ -44,88 +45,178 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
]
user_num = len(threepids)
- yield self.store.register(user_id=user1, token="123", password_hash=None)
-
- yield self.store.register(user_id=user2, token="456", password_hash=None)
+ self.store.register(user_id=user1, token="123", password_hash=None)
+ self.store.register(user_id=user2, token="456", password_hash=None)
+ self.pump()
now = int(self.hs.get_clock().time_msec())
- yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
- yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
- yield self.store.initialise_reserved_users(threepids)
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
- active_count = yield self.store.get_monthly_active_count()
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+ self.pump()
+
+ active_count = self.store.get_monthly_active_count()
# Test total counts
- self.assertEquals(active_count, user_num)
+ self.assertEquals(self.get_success(active_count), user_num)
# Test user is marked as active
-
- timestamp = yield self.store.user_last_seen_monthly_active(user1)
- self.assertTrue(timestamp)
- timestamp = yield self.store.user_last_seen_monthly_active(user2)
- self.assertTrue(timestamp)
+ timestamp = self.store.user_last_seen_monthly_active(user1)
+ self.assertTrue(self.get_success(timestamp))
+ timestamp = self.store.user_last_seen_monthly_active(user2)
+ self.assertTrue(self.get_success(timestamp))
# Test that users are never removed from the db.
self.hs.config.max_mau_value = 0
- self.hs.get_clock().advance_time(FORTY_DAYS)
+ self.reactor.advance(FORTY_DAYS)
- yield self.store.reap_monthly_active_users()
+ self.store.reap_monthly_active_users()
+ self.pump()
- active_count = yield self.store.get_monthly_active_count()
- self.assertEquals(active_count, user_num)
+ active_count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(active_count), user_num)
- # Test that regalar users are removed from the db
+ # Test that regular users are removed from the db
ru_count = 2
- yield self.store.upsert_monthly_active_user("@ru1:server")
- yield self.store.upsert_monthly_active_user("@ru2:server")
- active_count = yield self.store.get_monthly_active_count()
+ self.store.upsert_monthly_active_user("@ru1:server")
+ self.store.upsert_monthly_active_user("@ru2:server")
+ self.pump()
- self.assertEqual(active_count, user_num + ru_count)
+ active_count = self.store.get_monthly_active_count()
+ self.assertEqual(self.get_success(active_count), user_num + ru_count)
self.hs.config.max_mau_value = user_num
- yield self.store.reap_monthly_active_users()
+ self.store.reap_monthly_active_users()
+ self.pump()
- active_count = yield self.store.get_monthly_active_count()
- self.assertEquals(active_count, user_num)
+ active_count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(active_count), user_num)
- @defer.inlineCallbacks
def test_can_insert_and_count_mau(self):
- count = yield self.store.get_monthly_active_count()
- self.assertEqual(0, count)
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(0, self.get_success(count))
- yield self.store.upsert_monthly_active_user("@user:server")
- count = yield self.store.get_monthly_active_count()
+ self.store.upsert_monthly_active_user("@user:server")
+ self.pump()
- self.assertEqual(1, count)
+ count = self.store.get_monthly_active_count()
+ self.assertEqual(1, self.get_success(count))
- @defer.inlineCallbacks
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
- result = yield self.store.user_last_seen_monthly_active(user_id1)
- self.assertFalse(result == 0)
- yield self.store.upsert_monthly_active_user(user_id1)
- yield self.store.upsert_monthly_active_user(user_id2)
- result = yield self.store.user_last_seen_monthly_active(user_id1)
- self.assertTrue(result > 0)
- result = yield self.store.user_last_seen_monthly_active(user_id3)
- self.assertFalse(result == 0)
+ result = self.store.user_last_seen_monthly_active(user_id1)
+ self.assertFalse(self.get_success(result) == 0)
+
+ self.store.upsert_monthly_active_user(user_id1)
+ self.store.upsert_monthly_active_user(user_id2)
+ self.pump()
+
+ result = self.store.user_last_seen_monthly_active(user_id1)
+ self.assertGreater(self.get_success(result), 0)
+
+ result = self.store.user_last_seen_monthly_active(user_id3)
+ self.assertNotEqual(self.get_success(result), 0)
- @defer.inlineCallbacks
def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
- yield self.store.upsert_monthly_active_user("@user%d:server" % i)
- count = yield self.store.get_monthly_active_count()
- self.assertTrue(count, initial_users)
- yield self.store.reap_monthly_active_users()
- count = yield self.store.get_monthly_active_count()
- self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
-
- self.hs.get_clock().advance_time(FORTY_DAYS)
- yield self.store.reap_monthly_active_users()
- count = yield self.store.get_monthly_active_count()
- self.assertEquals(count, 0)
+ self.store.upsert_monthly_active_user("@user%d:server" % i)
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertTrue(self.get_success(count), initial_users)
+
+ self.store.reap_monthly_active_users()
+ self.pump()
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(
+ self.get_success(count), initial_users - self.hs.config.max_mau_value
+ )
+
+ self.reactor.advance(FORTY_DAYS)
+ self.store.reap_monthly_active_users()
+ self.pump()
+
+ count = self.store.get_monthly_active_count()
+ self.assertEquals(self.get_success(count), 0)
+
+ def test_populate_monthly_users_is_guest(self):
+ # Test that guest users are not added to mau list
+ user_id = "user_id"
+ self.store.register(
+ user_id=user_id, token="123", password_hash=None, make_guest=True
+ )
+ self.store.upsert_monthly_active_user = Mock()
+ self.store.populate_monthly_active_users(user_id)
+ self.pump()
+ self.store.upsert_monthly_active_user.assert_not_called()
+
+ def test_populate_monthly_users_should_update(self):
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.is_trial_user = Mock(
+ return_value=defer.succeed(False)
+ )
+
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(None)
+ )
+ self.store.populate_monthly_active_users('user_id')
+ self.pump()
+ self.store.upsert_monthly_active_user.assert_called_once()
+
+ def test_populate_monthly_users_should_not_update(self):
+ self.store.upsert_monthly_active_user = Mock()
+
+ self.store.is_trial_user = Mock(
+ return_value=defer.succeed(False)
+ )
+ self.store.user_last_seen_monthly_active = Mock(
+ return_value=defer.succeed(
+ self.hs.get_clock().time_msec()
+ )
+ )
+ self.store.populate_monthly_active_users('user_id')
+ self.pump()
+ self.store.upsert_monthly_active_user.assert_not_called()
+
+ def test_get_reserved_real_user_account(self):
+ # Test no reserved users, or reserved threepids
+ count = self.store.get_registered_reserved_users_count()
+ self.assertEquals(self.get_success(count), 0)
+ # Test reserved users but no registered users
+
+ user1 = '@user1:example.com'
+ user2 = '@user2:example.com'
+ user1_email = 'user1@example.com'
+ user2_email = 'user2@example.com'
+ threepids = [
+ {'medium': 'email', 'address': user1_email},
+ {'medium': 'email', 'address': user2_email},
+ ]
+ self.hs.config.mau_limits_reserved_threepids = threepids
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+
+ self.pump()
+ count = self.store.get_registered_reserved_users_count()
+ self.assertEquals(self.get_success(count), 0)
+
+        # Test reserved registered users
+ self.store.register(user_id=user1, token="123", password_hash=None)
+ self.store.register(user_id=user2, token="456", password_hash=None)
+ self.pump()
+
+ now = int(self.hs.get_clock().time_msec())
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ count = self.store.get_registered_reserved_users_count()
+ self.assertEquals(self.get_success(count), len(threepids))
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index b5b58ff660..c7a63f39b9 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -16,19 +16,18 @@
from twisted.internet import defer
-from synapse.storage.presence import PresenceStore
from synapse.types import UserID
from tests import unittest
-from tests.utils import MockClock, setup_test_homeserver
+from tests.utils import setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock())
+ hs = yield setup_test_homeserver(self.addCleanup)
- self.store = PresenceStore(None, hs)
+ self.store = hs.get_datastore()
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index dc3a2fd976..c125a0d797 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.store = ProfileStore(None, hs)
+ self.store = ProfileStore(hs.get_db_conn(), hs)
self.u_frank = UserID.from_string("@frank:test")
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
new file mode 100644
index 0000000000..f671599cb8
--- /dev/null
+++ b/tests/storage/test_purge.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client.v1 import room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PurgeTests(HomeserverTestCase):
+
+ user_id = "@red:server"
+ servlets = [room.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ def test_purge(self):
+ """
+ Purging a room will delete everything before the topological point.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+ second = self.helper.send(self.room_id, body="test2")
+ third = self.helper.send(self.room_id, body="test3")
+ last = self.helper.send(self.room_id, body="test4")
+
+ storage = self.hs.get_datastore()
+
+ # Get the topological token
+ event = storage.get_topological_token_for_event(last["event_id"])
+ self.pump()
+ event = self.successResultOf(event)
+
+ # Purge everything before this topological token
+ purge = storage.purge_history(self.room_id, event, True)
+ self.pump()
+ self.assertEqual(self.successResultOf(purge), None)
+
+ # Try and get the events
+ get_first = storage.get_event(first["event_id"])
+ get_second = storage.get_event(second["event_id"])
+ get_third = storage.get_event(third["event_id"])
+ get_last = storage.get_event(last["event_id"])
+ self.pump()
+
+ # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
+ # and last is not.
+ self.failureResultOf(get_first)
+ self.failureResultOf(get_second)
+ self.failureResultOf(get_third)
+ self.successResultOf(get_last)
+
+ def test_purge_wont_delete_extrems(self):
+ """
+ Purging a room will delete everything before the topological point.
+ """
+ # Send four messages to the room
+ first = self.helper.send(self.room_id, body="test1")
+ second = self.helper.send(self.room_id, body="test2")
+ third = self.helper.send(self.room_id, body="test3")
+ last = self.helper.send(self.room_id, body="test4")
+
+ storage = self.hs.get_datastore()
+
+ # Set the topological token higher than it should be
+ event = storage.get_topological_token_for_event(last["event_id"])
+ self.pump()
+ event = self.successResultOf(event)
+ event = "t{}-{}".format(
+ *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
+ )
+
+ # Purge everything before this topological token
+ purge = storage.purge_history(self.room_id, event, True)
+ self.pump()
+ f = self.failureResultOf(purge)
+ self.assertIn("greater than forward", f.value.args[0])
+
+ # Try and get the events
+ get_first = storage.get_event(first["event_id"])
+ get_second = storage.get_event(second["event_id"])
+ get_third = storage.get_event(third["event_id"])
+ get_last = storage.get_event(last["event_id"])
+ self.pump()
+
+ # Nothing is deleted.
+ self.successResultOf(get_first)
+ self.successResultOf(get_second)
+ self.successResultOf(get_third)
+ self.successResultOf(get_last)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index c4e9fb72bf..02bf975fbf 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RedactionTestCase(unittest.TestCase):
@@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+
self.depth = 1
@defer.inlineCallbacks
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eda122edc..3dfb7b903a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
+ "creation_ts": 1000,
},
(yield self.store.get_user_by_id(self.user_id)),
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c83ef60062..978c66133d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RoomMemberStoreTestCase(unittest.TestCase):
@@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
+
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new(
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index ebfd969b36..086a39d834 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
import tests.unittest
@@ -75,6 +76,45 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(len(s1), len(s2))
@defer.inlineCallbacks
+ def test_get_state_groups_ids(self):
+ e1 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, '', {}
+ )
+ e2 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+ )
+
+ state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
+ self.assertEqual(len(state_group_map), 1)
+ state_map = list(state_group_map.values())[0]
+ self.assertDictEqual(
+ state_map,
+ {
+ (EventTypes.Create, ''): e1.event_id,
+ (EventTypes.Name, ''): e2.event_id,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def test_get_state_groups(self):
+ e1 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, '', {}
+ )
+ e2 = yield self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
+ )
+
+ state_group_map = yield self.store.get_state_groups(
+ self.room, [e2.event_id])
+ self.assertEqual(len(state_group_map), 1)
+ state_list = list(state_group_map.values())[0]
+
+ self.assertEqual(
+ {ev.event_id for ev in state_list},
+ {e1.event_id, e2.event_id},
+ )
+
+ @defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
@@ -109,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we get the full state as of the final event
state = yield self.store.get_state_for_event(
- e5.event_id, None, filtered_types=None
+ e5.event_id,
)
self.assertIsNotNone(e4)
@@ -127,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, '')], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Member, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
self.assertStateMapEqual(
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
)
- # check we can use filtered_types to grab a specific room member
- # without filtering out the other event types
+ # check we can grab a specific room member without filtering out the
+ # other event types
state = yield self.store.get_state_for_event(
e5.event_id,
- [(EventTypes.Member, self.u_alice.to_string())],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ )
)
self.assertStateMapEqual(
@@ -165,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
state,
)
- # check that types=[], filtered_types=[EventTypes.Member]
- # doesn't return all members
+ # check that we can grab everything except members
state = yield self.store.get_state_for_event(
- e5.event_id, [], filtered_types=[EventTypes.Member]
+ e5.event_id, state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertStateMapEqual(
@@ -176,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
#######################################################
- # _get_some_state_from_cache tests against a full cache
+ # _get_state_for_group_using_cache tests against a full cache
#######################################################
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = list(group_ids.keys())[0]
- # test _get_some_state_from_cache correctly filters out members with types=[]
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [], filtered_types=[EventTypes.Member]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -197,9 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -207,6 +274,22 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
@@ -214,11 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -226,15 +313,31 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
- (e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+    # with specific types, and include_others=False
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
@@ -254,9 +357,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
- (e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5.event_id,
},
)
@@ -267,11 +367,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
- fetched_keys=(
- (e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
- ),
+ fetched_keys=((e1.type, e1.state_key),),
)
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
@@ -279,73 +375,118 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
self.assertEqual(is_all, False)
- self.assertEqual(
- known_absent,
- set(
- [
- (e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
- ]
- ),
- )
- self.assertDictEqual(
- state_dict_ids,
- {
- (e1.type, e1.state_key): e1.event_id,
- (e3.type, e3.state_key): e3.event_id,
- (e5.type, e5.state_key): e5.event_id,
- },
- )
+ self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
+ self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
# test that things work with a partial cache
- # test _get_some_state_from_cache correctly filters out members with types=[]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [], filtered_types=[EventTypes.Member]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- # test _get_some_state_from_cache correctly filters in members wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ room_id = self.room.to_string()
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+    # with wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
+ self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
self.assertDictEqual(
{
- (e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
- self.assertDictEqual(
- {
- (e1.type, e1.state_key): e1.event_id,
- (e5.type, e5.state_key): e5.event_id,
- },
- state_dict,
+ self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ # test _get_state_for_group_using_cache correctly filters in members
+    # with specific types, and include_others=False
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_members_cache,
+ group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
new file mode 100644
index 0000000000..14169afa96
--- /dev/null
+++ b/tests/storage/test_transactions.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests.unittest import HomeserverTestCase
+
+
+class TransactionStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+
+ def test_get_set_transactions(self):
+ """Tests that we can successfully get a non-existent entry for
+ destination retries, as well as testing tht we can set and get
+ correctly.
+ """
+ d = self.store.get_destination_retry_timings("example.com")
+ r = self.get_success(d)
+ self.assertIsNone(r)
+
+ d = self.store.set_destination_retry_timings("example.com", 50, 100)
+ self.get_success(d)
+
+ d = self.store.get_destination_retry_timings("example.com")
+ r = self.get_success(d)
+
+ self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r)
+
+ def test_initial_set_transactions(self):
+ """Tests that we can successfully set the destination retries (there
+ was a bug around invalidating the cache that broke this)
+ """
+ d = self.store.set_destination_retry_timings("example.com", 50, 100)
+ self.get_success(d)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index b46e0ea7e2..0dde1ab2fe 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = UserDirectoryStore(None, self.hs)
+ self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 2540604fcc..952a0a7b51 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -6,6 +6,7 @@ from twisted.internet.defer import maybeDeferred, succeed
from synapse.events import FrozenEvent
from synapse.types import Requester, UserID
from synapse.util import Clock
+from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
@@ -117,9 +118,10 @@ class MessageAcceptTests(unittest.TestCase):
}
)
- d = self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
- )
+ with LoggingContext(request="lying_event"):
+ d = self.handler.on_receive_pdu(
+ "test.serv", lying_event, sent_to_us_directly=True
+ )
# Step the reactor, so the database fetches come back
self.reactor.advance(1)
@@ -139,107 +141,3 @@ class MessageAcceptTests(unittest.TestCase):
self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
-
- def test_cant_hide_past_history(self):
- """
- If you send a message, you must be able to provide the direct
- prev_events that said event references.
- """
-
- def post_json(destination, path, data, headers=None, timeout=0):
- if path.startswith("/_matrix/federation/v1/get_missing_events/"):
- return {
- "events": [
- {
- "room_id": self.room_id,
- "sender": "@baduser:test.serv",
- "event_id": "three:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.message",
- "origin": "test.serv",
- "content": "hewwo?",
- "auth_events": [],
- "prev_events": [("four:test.serv", {})],
- }
- ]
- }
-
- self.http_client.post_json = post_json
-
- def get_json(destination, path, args, headers=None):
- if path.startswith("/_matrix/federation/v1/state_ids/"):
- d = self.successResultOf(
- self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
- )
-
- return succeed(
- {
- "pdu_ids": [
- y
- for x, y in d.items()
- if x == ("m.room.member", "@us:test")
- ],
- "auth_chain_ids": list(d.values()),
- }
- )
-
- self.http_client.get_json = get_json
-
- # Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
- )
- )[0]
-
- # Make a good event
- good_event = FrozenEvent(
- {
- "room_id": self.room_id,
- "sender": "@baduser:test.serv",
- "event_id": "one:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.message",
- "origin": "test.serv",
- "content": "hewwo?",
- "auth_events": [],
- "prev_events": [(most_recent, {})],
- }
- )
-
- d = self.handler.on_receive_pdu(
- "test.serv", good_event, sent_to_us_directly=True
- )
- self.reactor.advance(1)
- self.assertEqual(self.successResultOf(d), None)
-
- bad_event = FrozenEvent(
- {
- "room_id": self.room_id,
- "sender": "@baduser:test.serv",
- "event_id": "two:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.message",
- "origin": "test.serv",
- "content": "hewwo?",
- "auth_events": [],
- "prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
- }
- )
-
- d = self.handler.on_receive_pdu(
- "test.serv", bad_event, sent_to_us_directly=True
- )
- self.reactor.advance(1)
-
- extrem = maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
- )
- self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
-
- state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
- self.reactor.advance(1)
- self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys())
diff --git a/tests/test_mau.py b/tests/test_mau.py
new file mode 100644
index 0000000000..bdbacb8448
--- /dev/null
+++ b/tests/test_mau.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests REST events for /rooms paths."""
+
+import json
+
+from mock import Mock, NonCallableMock
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.http.server import JsonResource
+from synapse.rest.client.v2_alpha import register, sync
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
+
+
+class TestMauLimit(unittest.TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+
+ self.hs = setup_test_homeserver(
+ self.addCleanup,
+ "red",
+ http_client=None,
+ clock=self.clock,
+ reactor=self.reactor,
+ federation_client=Mock(),
+ ratelimiter=NonCallableMock(spec_set=["send_message"]),
+ )
+
+ self.store = self.hs.get_datastore()
+
+ self.hs.config.registrations_require_3pid = []
+ self.hs.config.enable_registration_captcha = False
+ self.hs.config.recaptcha_public_key = []
+
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.hs_disabled = False
+ self.hs.config.max_mau_value = 2
+ self.hs.config.mau_trial_days = 0
+ self.hs.config.server_notices_mxid = "@server:red"
+ self.hs.config.server_notices_mxid_display_name = None
+ self.hs.config.server_notices_mxid_avatar_url = None
+ self.hs.config.server_notices_room_name = "Test Server Notice Room"
+
+ self.resource = JsonResource(self.hs)
+ register.register_servlets(self.hs, self.resource)
+ sync.register_servlets(self.hs, self.resource)
+
+ def test_simple_deny_mau(self):
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+        # We've created and activated two users, so we shouldn't be able
+        # to register new users
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit3")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_allowed_after_a_month_mau(self):
+ # Create and sync so that the MAU counts get updated
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+
+ # Advance time by 31 days
+ self.reactor.advance(31 * 24 * 60 * 60)
+
+ self.store.reap_monthly_active_users()
+
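+        # pump the reactor once so that the reap above can finish before we
+        # try to register again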
+ self.reactor.advance(0)
+
+ # We should be able to register more users
+ token3 = self.create_user("kermit3")
+ self.do_sync_for_user(token3)
+
+ def test_trial_delay(self):
+ self.hs.config.mau_trial_days = 1
+
+ # We should be able to register more than the limit initially
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+ token3 = self.create_user("kermit3")
+ self.do_sync_for_user(token3)
+
+ # Advance time by 2 days
+ self.reactor.advance(2 * 24 * 60 * 60)
+
+ # Two users should be able to sync
+ self.do_sync_for_user(token1)
+ self.do_sync_for_user(token2)
+
+ # But the third should fail
+ with self.assertRaises(SynapseError) as cm:
+ self.do_sync_for_user(token3)
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ # And new registrations are now denied too
+ with self.assertRaises(SynapseError) as cm:
+ self.create_user("kermit4")
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_trial_users_cant_come_back(self):
+ self.hs.config.mau_trial_days = 1
+
+ # We should be able to register more than the limit initially
+ token1 = self.create_user("kermit1")
+ self.do_sync_for_user(token1)
+ token2 = self.create_user("kermit2")
+ self.do_sync_for_user(token2)
+ token3 = self.create_user("kermit3")
+ self.do_sync_for_user(token3)
+
+ # Advance time by 2 days
+ self.reactor.advance(2 * 24 * 60 * 60)
+
+ # Two users should be able to sync
+ self.do_sync_for_user(token1)
+ self.do_sync_for_user(token2)
+
+ # Advance by 2 months so everyone falls out of MAU
+ self.reactor.advance(60 * 24 * 60 * 60)
+ self.store.reap_monthly_active_users()
+ self.reactor.advance(0)
+
+ # We can create as many new users as we want
+ token4 = self.create_user("kermit4")
+ self.do_sync_for_user(token4)
+ token5 = self.create_user("kermit5")
+ self.do_sync_for_user(token5)
+ token6 = self.create_user("kermit6")
+ self.do_sync_for_user(token6)
+
+        # users 2 and 3 can come back to bring us back up to the MAU limit
+ self.do_sync_for_user(token2)
+ self.do_sync_for_user(token3)
+
+ # New trial users can still sync
+ self.do_sync_for_user(token4)
+ self.do_sync_for_user(token5)
+ self.do_sync_for_user(token6)
+
+        # But the old user can't
+ with self.assertRaises(SynapseError) as cm:
+ self.do_sync_for_user(token1)
+
+ e = cm.exception
+ self.assertEqual(e.code, 403)
+ self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def create_user(self, localpart):
+ request_data = json.dumps(
+ {
+ "username": localpart,
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ }
+ )
+
+ request, channel = make_request("POST", "/register", request_data)
+ render(request, self.resource, self.reactor)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"]
+ ).to_synapse_error()
+
+ access_token = channel.json_body["access_token"]
+
+ return access_token
+
+ def do_sync_for_user(self, token):
+ request, channel = make_request(
+ "GET", "/sync", access_token=token.encode('ascii')
+ )
+ render(request, self.resource, self.reactor)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"]
+ ).to_synapse_error()
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000000..17897711a1
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.metrics import InFlightGauge
+
+from tests import unittest
+
+
+class TestInFlightGauge(unittest.TestCase):
+ def test_basic(self):
+ gauge = InFlightGauge(
+ "test1", "",
+ labels=["test_label"],
+ sub_metrics=["foo", "bar"],
+ )
+
+ def handle1(metrics):
+ metrics.foo += 2
+ metrics.bar = max(metrics.bar, 5)
+
+ def handle2(metrics):
+ metrics.foo += 3
+ metrics.bar = max(metrics.bar, 7)
+
+ gauge.register(("key1",), handle1)
+
+ self.assert_dict({
+ "test1_total": {("key1",): 1},
+ "test1_foo": {("key1",): 2},
+ "test1_bar": {("key1",): 5},
+ }, self.get_metrics_from_gauge(gauge))
+
+ gauge.unregister(("key1",), handle1)
+
+ self.assert_dict({
+ "test1_total": {("key1",): 0},
+ "test1_foo": {("key1",): 0},
+ "test1_bar": {("key1",): 0},
+ }, self.get_metrics_from_gauge(gauge))
+
+ gauge.register(("key1",), handle1)
+ gauge.register(("key2",), handle2)
+
+ self.assert_dict({
+ "test1_total": {("key1",): 1, ("key2",): 1},
+ "test1_foo": {("key1",): 2, ("key2",): 3},
+ "test1_bar": {("key1",): 5, ("key2",): 7},
+ }, self.get_metrics_from_gauge(gauge))
+
+ gauge.unregister(("key2",), handle2)
+ gauge.register(("key1",), handle2)
+
+ self.assert_dict({
+ "test1_total": {("key1",): 2, ("key2",): 0},
+ "test1_foo": {("key1",): 5, ("key2",): 0},
+ "test1_bar": {("key1",): 7, ("key2",): 0},
+ }, self.get_metrics_from_gauge(gauge))
+
+ def get_metrics_from_gauge(self, gauge):
+ results = {}
+
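+        # Re-key each collected metric family by its tuple of label values, so
+        # tests can assert against {metric_name: {label_values: value}} directly.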
+ for r in gauge.collect():
+ results[r.name] = {
+ tuple(labels[x] for x in gauge.labels): value
+ for _, labels, value in r.samples
+ }
+
+ return results
diff --git a/tests/test_server.py b/tests/test_server.py
index ef74544e93..4045fdadc3 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,14 +1,35 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
import re
+from six import StringIO
+
from twisted.internet.defer import Deferred
-from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
+from synapse.http.site import SynapseSite, logger
from synapse.util import Clock
from tests import unittest
-from tests.server import make_request, render, setup_test_homeserver
+from tests.server import FakeTransport, make_request, render, setup_test_homeserver
class JsonResourceTests(unittest.TestCase):
@@ -121,3 +142,52 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b'400')
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+
+class SiteTestCase(unittest.HomeserverTestCase):
+ def test_lose_connection(self):
+ """
+        When the connection is lost, the request URI should be logged with the
+        access token redacted.
+ """
+
+ class HangingResource(Resource):
+ """
+ A Resource that strategically hangs, as if it were processing an
+ answer.
+ """
+
+ def render(self, request):
+ return NOT_DONE_YET
+
+ # Set up a logging handler that we can inspect afterwards
+ output = StringIO()
+ handler = logging.StreamHandler(output)
+ logger.addHandler(handler)
+ old_level = logger.level
+ logger.setLevel(10)
+ self.addCleanup(logger.setLevel, old_level)
+ self.addCleanup(logger.removeHandler, handler)
+
+        # Make a resource and a Site; the resource will hang and allow us to
+ # time out the request while it's 'processing'
+ base_resource = Resource()
+ base_resource.putChild(b'', HangingResource())
+ site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
+
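+        # wire a client protocol up to the server protocol over in-memory
+        # fake transports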
+ server = site.buildProtocol(None)
+ client = AccumulatingProtocol()
+ client.makeConnection(FakeTransport(server, self.reactor))
+ server.makeConnection(FakeTransport(client, self.reactor))
+
+ # Send a request with an access token that will get redacted
+ server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
+ self.pump()
+
+ # Lose the connection
+ e = Failure(Exception("Failed123"))
+ server.connectionLost(e)
+ handler.flush()
+
+ # Our access token is redacted and the failure reason is logged.
+ self.assertIn("/?access_token=<redacted>", output.getvalue())
+ self.assertIn("Failed123", output.getvalue())
diff --git a/tests/test_state.py b/tests/test_state.py
index 96fdb8636c..e20c33322a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.auth import Auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.events import FrozenEvent
from synapse.state import StateHandler, StateResolutionHandler
@@ -117,6 +117,9 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
+ def get_room_version(self, room_id):
+ return RoomVersions.V1
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase):
def test_branch_no_conflict(self):
graph = Graph(
nodes={
- "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
+ "START": DictObj(
+ type=EventTypes.Create, state_key="", content={}, depth=1
+ ),
"A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3),
"C": DictObj(type=EventTypes.Name, state_key="", depth=3),
diff --git a/tests/test_types.py b/tests/test_types.py
index be072d402b..0f5c8bfaf9 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -14,12 +14,12 @@
# limitations under the License.
from synapse.api.errors import SynapseError
-from synapse.server import HomeServer
from synapse.types import GroupID, RoomAlias, UserID
from tests import unittest
+from tests.utils import TestHomeServer
-mock_homeserver = HomeServer(hostname="my.domain")
+mock_homeserver = TestHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 45a78338d6..2eea3b098b 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,7 +21,7 @@ from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
import tests.unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
logger = logging.getLogger(__name__)
@@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+
@defer.inlineCallbacks
def test_filtering(self):
#
@@ -94,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
# the erasey user gets erased
- self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+ yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
filtered = yield filter_events_for_server(
diff --git a/tests/unittest.py b/tests/unittest.py
index d852e2465a..a59291cc60 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import hmac
import logging
from mock import Mock
@@ -22,14 +24,17 @@ from canonicaljson import json
import twisted
import twisted.logger
+from twisted.internet.defer import Deferred
from twisted.trial import unittest
from synapse.http.server import JsonResource
+from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util.logcontext import LoggingContextFilter
from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.utils import default_config
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()
@@ -151,6 +156,7 @@ class HomeserverTestCase(TestCase):
hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id.
"""
+
servlets = []
hijack_auth = True
@@ -217,7 +223,17 @@ class HomeserverTestCase(TestCase):
Function to be overridden in subclasses.
"""
- raise NotImplementedError()
+ hs = self.setup_test_homeserver()
+ return hs
+
+ def default_config(self, name="test"):
+ """
+ Get a default HomeServer config object.
+
+ Args:
+ name (str): The homeserver name/domain.
+ """
+ return default_config(name)
def prepare(self, reactor, clock, homeserver):
"""
@@ -234,7 +250,9 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
- def make_request(self, method, path, content=b""):
+ def make_request(
+ self, method, path, content=b"", access_token=None, request=SynapseRequest
+ ):
"""
Create a SynapseRequest at the path using the method and containing the
given content.
@@ -252,7 +270,7 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
- return make_request(method, path, content)
+ return make_request(method, path, content, access_token, request)
def render(self, request):
"""
@@ -279,3 +297,81 @@ class HomeserverTestCase(TestCase):
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+
+ def pump(self, by=0.0):
+ """
+ Pump the reactor enough that Deferreds will fire.
+ """
+ self.reactor.pump([by] * 100)
+
+ def get_success(self, d):
+ if not isinstance(d, Deferred):
+ return d
+ self.pump()
+ return self.successResultOf(d)
+
+ def register_user(self, username, password, admin=False):
+ """
+ Register a user. Requires the Admin API be registered.
+
+ Args:
+ username (bytes/unicode): The user part of the new user.
+ password (bytes/unicode): The password of the new user.
+ admin (bool): Whether the user should be created as an admin
+ or not.
+
+ Returns:
+ The MXID of the new user (unicode).
+ """
+ self.hs.config.registration_shared_secret = u"shared"
+
+ # Create the user
+ request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
+ self.render(request)
+ nonce = channel.json_body["nonce"]
+
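+        # Recompute the MAC that the shared-secret registration endpoint
+        # expects: an HMAC-SHA1, keyed with the shared secret, over the nonce,
+        # username, password and admin flag, joined with NUL bytes.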
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')])
+ if admin:
+ nonce_str += b"\x00admin"
+ else:
+ nonce_str += b"\x00notadmin"
+ want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": username,
+ "password": password,
+ "admin": admin,
+ "mac": want_mac,
+ }
+ )
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ user_id = channel.json_body["user_id"]
+ return user_id
+
+ def login(self, username, password, device_id=None):
+ """
+ Log in a user, and get an access token. Requires the Login API be
+ registered.
+
+ """
+ body = {"type": "m.login.password", "user": username, "password": password}
+ if device_id:
+ body["device_id"] = device_id
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200)
+
+ access_token = channel.json_body["access_token"].encode('ascii')
+ return access_token
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 5cbada4eda..50bc7702d2 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -65,7 +65,6 @@ class ExpiringCacheTestCase(unittest.TestCase):
def test_time_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, expiry_ms=1000)
- cache.start()
cache["key"] = 1
clock.advance_time(0.5)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 4633db77b3..8adaee3c8d 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -159,6 +159,11 @@ class LoggingContextTestCase(unittest.TestCase):
self.assertEqual(r, "bum")
self._check_test_key("one")
+ def test_nested_logging_context(self):
+ with LoggingContext(request="foo"):
+ nested_context = logcontext.nested_logging_context(suffix="bar")
+ self.assertEqual(nested_context.request, "foo-bar")
+
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
diff --git a/tests/utils.py b/tests/utils.py
index f1683e7a06..022a868501 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -16,7 +16,9 @@
import atexit
import hashlib
import os
+import time
import uuid
+import warnings
from inspect import getcallargs
from mock import Mock, patch
@@ -24,12 +26,14 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
+from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
+from synapse.config.server import ServerConfig
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
-from synapse.storage import PostgresEngine
-from synapse.storage.engines import create_engine
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import (
_get_or_create_schema_state,
_setup_new_database,
@@ -40,6 +44,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
+LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
@@ -91,10 +96,79 @@ def setupdb():
atexit.register(_cleanup)
+def default_config(name):
+ """
+ Create a reasonable test config.
+ """
+ config = Mock()
+ config.signing_key = [MockKey()]
+ config.event_cache_size = 1
+ config.enable_registration = True
+ config.macaroon_secret_key = "not even a little secret"
+ config.expire_access_token = False
+ config.server_name = name
+ config.trusted_third_party_id_servers = []
+ config.room_invite_state_types = []
+ config.password_providers = []
+ config.worker_replication_url = ""
+ config.worker_app = None
+ config.email_enable_notifs = False
+ config.block_non_admin_invites = False
+ config.federation_domain_whitelist = None
+ config.federation_rc_reject_limit = 10
+ config.federation_rc_sleep_limit = 10
+ config.federation_rc_sleep_delay = 100
+ config.federation_rc_concurrent = 10
+ config.filter_timeline_limit = 5000
+ config.user_directory_search_all_users = False
+ config.replicate_user_profiles_to = []
+ config.user_consent_server_notice_content = None
+ config.block_events_without_consent_error = None
+ config.media_storage_providers = []
+ config.autocreate_auto_join_rooms = True
+ config.auto_join_rooms = []
+ config.limit_usage_by_mau = False
+ config.hs_disabled = False
+ config.hs_disabled_message = ""
+ config.hs_disabled_limit_type = ""
+ config.max_mau_value = 50
+ config.mau_trial_days = 0
+ config.mau_limits_reserved_threepids = []
+ config.admin_contact = None
+ config.rc_messages_per_second = 10000
+ config.rc_message_burst_count = 10000
+
+ config.use_frozen_dicts = False
+
+ # we need a sane default_room_version, otherwise attempts to create rooms will
+ # fail.
+ config.default_room_version = "1"
+
+ # disable user directory updates, because they get done in the
+ # background, which upsets the test runner.
+ config.update_user_directory = False
+
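+    # config is a Mock, so wire is_threepid_reserved through to the real
+    # ServerConfig implementation instead of returning a Mock for every call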
+ def is_threepid_reserved(threepid):
+ return ServerConfig.is_threepid_reserved(config, threepid)
+
+ config.is_threepid_reserved.side_effect = is_threepid_reserved
+
+ return config
+
+
+class TestHomeServer(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+
@defer.inlineCallbacks
def setup_test_homeserver(
- cleanup_func, name="test", datastore=None, config=None, reactor=None,
- homeserverToUse=HomeServer, **kargs
+ cleanup_func,
+ name="test",
+ datastore=None,
+ config=None,
+ reactor=None,
+ homeserverToUse=TestHomeServer,
+ **kargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -110,49 +184,8 @@ def setup_test_homeserver(
from twisted.internet import reactor
if config is None:
- config = Mock()
- config.signing_key = [MockKey()]
- config.event_cache_size = 1
- config.enable_registration = True
- config.macaroon_secret_key = "not even a little secret"
- config.expire_access_token = False
- config.server_name = name
- config.trusted_third_party_id_servers = []
- config.room_invite_state_types = []
- config.password_providers = []
- config.worker_replication_url = ""
- config.worker_app = None
- config.email_enable_notifs = False
- config.block_non_admin_invites = False
- config.federation_domain_whitelist = None
- config.federation_rc_reject_limit = 10
- config.federation_rc_sleep_limit = 10
- config.federation_rc_sleep_delay = 100
- config.federation_rc_concurrent = 10
- config.filter_timeline_limit = 5000
- config.user_directory_search_all_users = False
- config.replicate_user_profiles_to = []
- config.user_consent_server_notice_content = None
- config.block_events_without_consent_error = None
- config.media_storage_providers = []
- config.auto_join_rooms = []
- config.limit_usage_by_mau = False
- config.hs_disabled = False
- config.hs_disabled_message = ""
- config.hs_disabled_limit_type = ""
- config.max_mau_value = 50
- config.mau_limits_reserved_threepids = []
- config.admin_uri = None
-
- # we need a sane default_room_version, otherwise attempts to create rooms will
- # fail.
- config.default_room_version = "1"
-
- # disable user directory updates, because they get done in the
- # background, which upsets the test runner.
- config.update_user_directory = False
-
- config.use_frozen_dicts = True
+ config = default_config(name)
+
config.ldap_enabled = False
if "clock" not in kargs:
@@ -218,22 +251,44 @@ def setup_test_homeserver(
else:
# We need to do cleanup on PostgreSQL
def cleanup():
+ import psycopg2
+
# Close all the db pools
hs.get_db_pool().close()
+ dropped = False
+
# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB, user=POSTGRES_USER
)
db_conn.autocommit = True
cur = db_conn.cursor()
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- db_conn.commit()
+
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for x in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
cur.close()
db_conn.close()
- # Register the cleanup hook
- cleanup_func(cleanup)
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
hs.setup()
else:
@@ -307,7 +362,9 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request')
@defer.inlineCallbacks
- def trigger(self, http_method, path, content, mock_request, federation_auth=False):
+ def trigger(
+ self, http_method, path, content, mock_request, federation_auth_origin=None
+ ):
""" Fire an HTTP event.
Args:
@@ -316,6 +373,7 @@ class MockHttpResource(HttpServer):
content : The HTTP body
mock_request : Mocked request to pass to the event so it can get
content.
+ federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns:
A tuple of (code, response)
Raises:
@@ -336,8 +394,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-"
headers = {}
- if federation_auth:
- headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
+ if federation_auth_origin is not None:
+ headers[b"Authorization"] = [
+ b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
+ ]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
@@ -540,3 +600,32 @@ class DeferredMockCallable(object):
"Expected not to received any calls, got:\n"
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
)
+
+
+@defer.inlineCallbacks
+def create_room(hs, room_id, creator_id):
+ """Creates and persist a creation event for the given room
+
+ Args:
+ hs
+ room_id (str)
+ creator_id (str)
+ """
+
+ store = hs.get_datastore()
+ event_builder_factory = hs.get_event_builder_factory()
+ event_creation_handler = hs.get_event_creation_handler()
+
+ builder = event_builder_factory.new(
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": creator_id,
+ "room_id": room_id,
+ "content": {},
+ }
+ )
+
+ event, context = yield event_creation_handler.create_new_client_event(builder)
+
+ yield store.persist_event(event, context)
|