diff options
author | Michael Telatynski <7t3chguy@gmail.com> | 2018-07-24 17:17:46 +0100 |
---|---|---|
committer | Michael Telatynski <7t3chguy@gmail.com> | 2018-07-24 17:17:46 +0100 |
commit | 87951d3891efb5bccedf72c12b3da0d6ab482253 (patch) | |
tree | de7d997567c66c5a4d8743c1f3b9d6b474f5cfd9 /tests | |
parent | if inviter_display_name == ""||None then default to inviter MXID (diff) | |
parent | Merge pull request #3595 from matrix-org/erikj/use_deltas (diff) | |
download | synapse-87951d3891efb5bccedf72c12b3da0d6ab482253.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into t3chguy/default_inviter_display_name_3pid
Diffstat (limited to 'tests')
89 files changed, 4515 insertions, 2059 deletions
diff --git a/tests/__init__.py b/tests/__init__.py index bfebb0f644..24006c949e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -12,3 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from twisted.trial import util + +util.DEFAULT_TIMEOUT_DURATION = 10 diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4575dd9834..5f158ec4b9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -13,16 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pymacaroons from mock import Mock + +import pymacaroons + from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID + from tests import unittest -from tests.utils import setup_test_homeserver, mock_getRawHeaders +from tests.utils import mock_getRawHeaders, setup_test_homeserver class TestHandlers(object): @@ -86,16 +89,53 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), self.test_user) + + @defer.inlineCallbacks + def test_get_user_by_req_appservice_valid_token_good_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "192.168.10.10" request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) + def test_get_user_by_req_appservice_valid_token_bad_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "131.111.8.42" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + d = self.auth.get_user_by_req(request) + self.failureResultOf(d, AuthError) + def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None) @@ -119,12 +159,16 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -133,12 +177,16 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token_bad_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index dcceca7f3e..836a23fb54 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -13,19 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock -from tests.utils import ( - MockHttpResource, DeferredMockCallable, setup_test_homeserver -) +import jsonschema + +from twisted.internet import defer + +from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent -from synapse.api.errors import SynapseError -import jsonschema +from tests import unittest +from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver user_localpart = "test_user" diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 7586ea9053..891e0cc973 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,14 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice import ApplicationService +import re + +from mock import Mock from twisted.internet import defer -from mock import Mock -from tests import unittest +from synapse.appservice import ApplicationService -import re +from tests import unittest def _regex(regex, exclusive=True): @@ -36,6 +37,7 @@ class ApplicationServiceTestCase(unittest.TestCase): id="unique_identifier", url="some_url", token="some_token", + hostname="matrix.org", # only used by get_groups_for_user namespaces={ ApplicationService.NS_USERS: [], ApplicationService.NS_ROOMS: [], diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index e5a902f734..b9f4863e9a 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,15 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + +from twisted.internet import defer + from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( - _ServiceQueuer, _TransactionController, _Recoverer + _Recoverer, + _ServiceQueuer, + _TransactionController, ) -from twisted.internet import defer -from ..utils import MockClock -from mock import Mock +from synapse.util.logcontext import make_deferred_yieldable + from tests import unittest +from ..utils import MockClock + class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): @@ -204,7 +211,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock(return_value=d) + self.txn_ctrl.send = Mock( + side_effect=lambda x, y: make_deferred_yieldable(d), + ) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") @@ -235,7 +244,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0)) + + def do_send(x, y): + return make_deferred_yieldable(send_return_list.pop(0)) + self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent self.queuer.enqueue(srv1, srv_1_event) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 8f57fbeb23..eb7f0ab12a 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -12,10 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os.path +import re import shutil import tempfile + from synapse.config.homeserver import HomeServerConfig + from tests import unittest @@ -23,7 +27,6 @@ class ConfigGenerationTestCase(unittest.TestCase): def setUp(self): self.dir = tempfile.mkdtemp() - print self.dir self.file = os.path.join(self.dir, "homeserver.yaml") def tearDown(self): @@ -48,3 +51,16 @@ class ConfigGenerationTestCase(unittest.TestCase): ]), set(os.listdir(self.dir)) ) + + self.assert_log_filename_is( + os.path.join(self.dir, "lemurs.win.log.config"), + os.path.join(os.getcwd(), "homeserver.log"), + ) + + def assert_log_filename_is(self, log_config_file, expected): + with open(log_config_file) as f: + config = f.read() + # find the 'filename' line + matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) + self.assertEqual(1, len(matches)) + self.assertEqual(matches[0], expected) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 161a87d7e3..5c422eff38 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -15,8 +15,11 @@ import os.path import shutil import tempfile + import yaml + from synapse.config.homeserver import HomeServerConfig + from tests import unittest @@ -24,7 +27,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def setUp(self): self.dir = tempfile.mkdtemp() - print self.dir + print(self.dir) self.file = os.path.join(self.dir, "homeserver.yaml") def tearDown(self): diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 47cb328a01..cd11871b80 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -14,15 +14,13 @@ # limitations under the License. -from tests import unittest - -from synapse.events.builder import EventBuilder -from synapse.crypto.event_signing import add_hashes_and_signatures - +import nacl.signing from unpaddedbase64 import decode_base64 -import nacl.signing +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events.builder import EventBuilder +from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py new file mode 100644 index 0000000000..a9d37fe084 --- /dev/null +++ b/tests/crypto/test_keyring.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from mock import Mock + +import signedjson.key +import signedjson.sign + +from twisted.internet import defer, reactor + +from synapse.api.errors import SynapseError +from synapse.crypto import keyring +from synapse.util import Clock, logcontext +from synapse.util.logcontext import LoggingContext + +from tests import unittest, utils + + +class MockPerspectiveServer(object): + def __init__(self): + self.server_name = "mock_server" + self.key = signedjson.key.generate_signing_key(0) + + def get_verify_keys(self): + vk = signedjson.key.get_verify_key(self.key) + return { + "%s:%s" % (vk.alg, vk.version): vk, + } + + def get_signed_key(self, server_name, verify_key): + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + res = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": time.time() * 1000 + 3600, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) + } + } + } + signedjson.sign.sign_json(res, self.server_name, self.key) + return res + + +class KeyringTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.mock_perspective_server = MockPerspectiveServer() + self.http_client = Mock() + self.hs = yield utils.setup_test_homeserver( + handlers=None, + http_client=self.http_client, + ) + self.hs.config.perspectives = { + self.mock_perspective_server.server_name: + self.mock_perspective_server.get_verify_keys() + } + + def check_context(self, _, expected): + self.assertEquals( + getattr(LoggingContext.current_context(), "request", None), + expected + ) + + @defer.inlineCallbacks + def test_wait_for_previous_lookups(self): + sentinel_context = LoggingContext.current_context() + + kr = keyring.Keyring(self.hs) + + lookup_1_deferred = defer.Deferred() + lookup_2_deferred = defer.Deferred() + + with LoggingContext("one") as context_one: + context_one.request = "one" + + wait_1_deferred = kr.wait_for_previous_lookups( + ["server1"], + {"server1": lookup_1_deferred}, + ) + + # there were no previous lookups, so the deferred should be ready + self.assertTrue(wait_1_deferred.called) + # ... so we should have preserved the LoggingContext. + self.assertIs(LoggingContext.current_context(), context_one) + wait_1_deferred.addBoth(self.check_context, "one") + + with LoggingContext("two") as context_two: + context_two.request = "two" + + # set off another wait. It should block because the first lookup + # hasn't yet completed. + wait_2_deferred = kr.wait_for_previous_lookups( + ["server1"], + {"server1": lookup_2_deferred}, + ) + self.assertFalse(wait_2_deferred.called) + # ... so we should have reset the LoggingContext. + self.assertIs(LoggingContext.current_context(), sentinel_context) + wait_2_deferred.addBoth(self.check_context, "two") + + # let the first lookup complete (in the sentinel context) + lookup_1_deferred.callback(None) + + # now the second wait should complete and restore our + # loggingcontext. + yield wait_2_deferred + + @defer.inlineCallbacks + def test_verify_json_objects_for_server_awaits_previous_requests(self): + clock = Clock(reactor) + key1 = signedjson.key.generate_signing_key(1) + + kr = keyring.Keyring(self.hs) + json1 = {} + signedjson.sign.sign_json(json1, "server10", key1) + + persp_resp = { + "server_keys": [ + self.mock_perspective_server.get_signed_key( + "server10", + signedjson.key.get_verify_key(key1) + ), + ] + } + persp_deferred = defer.Deferred() + + @defer.inlineCallbacks + def get_perspectives(**kwargs): + self.assertEquals( + LoggingContext.current_context().request, "11", + ) + with logcontext.PreserveLoggingContext(): + yield persp_deferred + defer.returnValue(persp_resp) + self.http_client.post_json.side_effect = get_perspectives + + with LoggingContext("11") as context_11: + context_11.request = "11" + + # start off a first set of lookups + res_deferreds = kr.verify_json_objects_for_server( + [("server10", json1), + ("server11", {}) + ] + ) + + # the unsigned json should be rejected pretty quickly + self.assertTrue(res_deferreds[1].called) + try: + yield res_deferreds[1] + self.assertFalse("unsigned json didn't cause a failure") + except SynapseError: + pass + + self.assertFalse(res_deferreds[0].called) + res_deferreds[0].addBoth(self.check_context, None) + + # wait a tick for it to send the request to the perspectives server + # (it first tries the datastore) + yield clock.sleep(1) # XXX find out why this takes so long! + self.http_client.post_json.assert_called_once() + + self.assertIs(LoggingContext.current_context(), context_11) + + context_12 = LoggingContext("12") + context_12.request = "12" + with logcontext.PreserveLoggingContext(context_12): + # a second request for a server with outstanding requests + # should block rather than start a second call + self.http_client.post_json.reset_mock() + self.http_client.post_json.return_value = defer.Deferred() + + res_deferreds_2 = kr.verify_json_objects_for_server( + [("server10", json1)], + ) + yield clock.sleep(1) + self.http_client.post_json.assert_not_called() + res_deferreds_2[0].addBoth(self.check_context, None) + + # complete the first request + with logcontext.PreserveLoggingContext(): + persp_deferred.callback(persp_resp) + self.assertIs(LoggingContext.current_context(), context_11) + + with logcontext.PreserveLoggingContext(): + yield res_deferreds[0] + yield res_deferreds_2[0] + + @defer.inlineCallbacks + def test_verify_json_for_server(self): + kr = keyring.Keyring(self.hs) + + key1 = signedjson.key.generate_signing_key(1) + yield self.hs.datastore.store_server_verify_key( + "server9", "", time.time() * 1000, + signedjson.key.get_verify_key(key1), + ) + json1 = {} + signedjson.sign.sign_json(json1, "server9", key1) + + sentinel_context = LoggingContext.current_context() + + with LoggingContext("one") as context_one: + context_one.request = "one" + + defer = kr.verify_json_for_server("server9", {}) + try: + yield defer + self.fail("should fail on unsigned json") + except SynapseError: + pass + self.assertIs(LoggingContext.current_context(), context_one) + + defer = kr.verify_json_for_server("server9", json1) + self.assertFalse(defer.called) + self.assertIs(LoggingContext.current_context(), sentinel_context) + yield defer + + self.assertIs(LoggingContext.current_context(), context_one) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index dfc870066e..f51d99419e 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -14,11 +14,11 @@ # limitations under the License. -from .. import unittest - from synapse.events import FrozenEvent from synapse.events.utils import prune_event, serialize_event +from .. import unittest + def MockEvent(**kwargs): if "event_id" not in kwargs: diff --git a/tests/metrics/__init__.py b/tests/federation/__init__.py index e69de29bb2..e69de29bb2 100644 --- a/tests/metrics/__init__.py +++ b/tests/federation/__init__.py diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py new file mode 100644 index 0000000000..c91e25f54f --- /dev/null +++ b/tests/federation/test_federation_server.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.events import FrozenEvent +from synapse.federation.federation_server import server_matches_acl_event + +from tests import unittest + + +@unittest.DEBUG +class ServerACLsTestCase(unittest.TestCase): + def test_blacklisted_server(self): + e = _create_acl_event({ + "allow": ["*"], + "deny": ["evil.com"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("evil.com", e)) + self.assertFalse(server_matches_acl_event("EVIL.COM", e)) + + self.assertTrue(server_matches_acl_event("evil.com.au", e)) + self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) + + def test_block_ip_literals(self): + e = _create_acl_event({ + "allow_ip_literals": False, + "allow": ["*"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("1.2.3.4", e)) + self.assertTrue(server_matches_acl_event("1a.2.3.4", e)) + self.assertFalse(server_matches_acl_event("[1:2::]", e)) + self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + + +def _create_acl_event(content): + return FrozenEvent({ + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content + }) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7fe88172c0..57c0771cf3 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest -from tests.utils import MockClock from synapse.handlers.appservice import ApplicationServicesHandler -from mock import Mock +from tests.utils import MockClock + +from .. import unittest class AppServiceHandlerTestCase(unittest.TestCase): @@ -31,6 +33,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore = Mock(return_value=self.mock_store) + self.mock_store.get_received_ts.return_value = 0 hs.get_application_service_api = Mock(return_value=self.mock_as_api) hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) hs.get_clock.return_value = MockClock() @@ -53,7 +56,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_new_events_for_appservice.return_value = (0, [event]) + self.mock_store.get_new_events_for_appservice.side_effect = [ + (0, [event]), + (0, []) + ] self.mock_as_api.push = Mock() yield self.handler.notify_interested_services(0) self.mock_scheduler.submit_event_for_as.assert_called_once_with( @@ -75,7 +81,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.mock_as_api.push = Mock() self.mock_as_api.query_user = Mock() - self.mock_store.get_new_events_for_appservice.return_value = (0, [event]) + self.mock_store.get_new_events_for_appservice.side_effect = [ + (0, [event]), + (0, []) + ] yield self.handler.notify_interested_services(0) self.mock_as_api.query_user.assert_called_once_with( services[0], user_id @@ -98,7 +107,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): ) self.mock_as_api.push = Mock() self.mock_as_api.query_user = Mock() - self.mock_store.get_new_events_for_appservice.return_value = (0, [event]) + self.mock_store.get_new_events_for_appservice.side_effect = [ + (0, [event]), + (0, []) + ] yield self.handler.notify_interested_services(0) self.assertFalse( self.mock_as_api.query_user.called, diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1822dcf1e0..2e5e8e4dec 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,11 +14,13 @@ # limitations under the License. import pymacaroons + from twisted.internet import defer import synapse import synapse.api.errors from synapse.handlers.auth import AuthHandler + from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 2eaaa8253c..633a0b7f36 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -17,9 +17,8 @@ from twisted.internet import defer import synapse.api.errors import synapse.handlers.device - import synapse.storage -from synapse import types + from tests import unittest, utils user1 = "@boris:aaa" @@ -179,6 +178,6 @@ class DeviceTestCase(unittest.TestCase): if ip is not None: yield self.store.insert_client_ip( - types.UserID.from_string(user_id), + user_id, access_token, ip, "user_agent", device_id) self.clock.advance_time(1000) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5712773909..a353070316 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -14,14 +14,14 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer + from synapse.handlers.directory import DirectoryHandler from synapse.types import RoomAlias +from tests import unittest from tests.utils import setup_test_homeserver @@ -35,21 +35,20 @@ class DirectoryTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.mock_federation = Mock(spec=[ - "make_query", - "register_edu_handler", - ]) + self.mock_federation = Mock() + self.mock_registry = Mock() self.query_handlers = {} def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler - self.mock_federation.register_query_handler = register_query_handler + self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( http_client=None, resource_for_federation=Mock(), - replication_layer=self.mock_federation, + federation_client=self.mock_federation, + federation_registry=self.mock_registry, ) hs.handlers = DirectoryHandlers(hs) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 19f5ed6bce..ca1542236d 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,13 +14,14 @@ # limitations under the License. import mock -from synapse.api import errors + from twisted.internet import defer import synapse.api.errors import synapse.handlers.e2e_keys - import synapse.storage +from synapse.api import errors + from tests import unittest, utils @@ -34,7 +35,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): def setUp(self): self.hs = yield utils.setup_test_homeserver( handlers=None, - replication_layer=mock.Mock(), + federation_client=mock.Mock(), ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) @@ -143,7 +144,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): except errors.SynapseError: pass - @unittest.DEBUG @defer.inlineCallbacks def test_claim_one_time_key(self): local_user = "@boris:" + self.hs.hostname diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index de06a6ad30..121ce78634 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -14,18 +14,22 @@ # limitations under the License. -from tests import unittest - from mock import Mock, call from synapse.api.constants import PresenceState from synapse.handlers.presence import ( - handle_update, handle_timeout, - IDLE_TIMER, SYNC_ONLINE_TIMEOUT, LAST_ACTIVE_GRANULARITY, FEDERATION_TIMEOUT, FEDERATION_PING_INTERVAL, + FEDERATION_TIMEOUT, + IDLE_TIMER, + LAST_ACTIVE_GRANULARITY, + SYNC_ONLINE_TIMEOUT, + handle_timeout, + handle_update, ) from synapse.storage.presence import UserPresenceState +from tests import unittest + class PresenceUpdateTestCase(unittest.TestCase): def test_offline_to_online(self): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 2a203129ca..dc17918a3d 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock, NonCallableMock +from twisted.internet import defer + import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver @@ -37,23 +37,23 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.mock_federation = Mock(spec=[ - "make_query", - "register_edu_handler", - ]) + self.mock_federation = Mock() + self.mock_registry = Mock() self.query_handlers = {} def register_query_handler(query_type, handler): self.query_handlers[query_type] = handler - self.mock_federation.register_query_handler = register_query_handler + self.mock_registry.register_query_handler = register_query_handler hs = yield setup_test_homeserver( http_client=None, handlers=None, resource_for_federation=Mock(), - replication_layer=self.mock_federation, + federation_client=self.mock_federation, + federation_server=Mock(), + federation_registry=self.mock_registry, ratelimiter=NonCallableMock(spec_set=[ "send_message", ]) @@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase): self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) - hs.handlers = ProfileHandlers(hs) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") @@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) - self.handler = hs.get_handlers().profile_handler + self.handler = hs.get_profile_handler() @defer.inlineCallbacks def test_get_my_name(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index c8cf9a63ec..025fa1be81 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester from tests.utils import setup_test_homeserver -from mock import Mock +from .. import unittest class RegistrationHandlers(object): @@ -40,13 +41,14 @@ class RegistrationTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver( handlers=None, http_client=None, - expire_access_token=True) + expire_access_token=True, + profile_handler=Mock(), + ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret')) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler - self.hs.get_handlers().profile_handler = Mock() @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index dbe50383da..b08856f763 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,19 +14,24 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - -from mock import Mock, call, ANY import json -from ..utils import ( - MockHttpResource, MockClock, DeferredMockCallable, setup_test_homeserver -) +from mock import ANY, Mock, call + +from twisted.internet import defer from synapse.api.errors import AuthError from synapse.types import UserID +from tests import unittest + +from ..utils import ( + DeferredMockCallable, + MockClock, + MockHttpResource, + setup_test_homeserver, +) + def _expect_edu(destination, edu_type, content, origin="test"): return { @@ -58,7 +63,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.mock_federation_resource = MockHttpResource() - mock_notifier = Mock(spec=["on_new_event"]) + mock_notifier = Mock() self.on_new_event = mock_notifier.on_new_event self.auth = Mock(spec=[]) @@ -76,9 +81,12 @@ class TypingNotificationsTestCase(unittest.TestCase): "set_received_txn_response", "get_destination_retry_timings", "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", ]), state_handler=self.state_handler, - handlers=None, + handlers=Mock(), notifier=mock_notifier, resource_for_client=Mock(), resource_for_federation=self.mock_federation_resource, @@ -122,6 +130,15 @@ class TypingNotificationsTestCase(unittest.TestCase): return set(str(u) for u in self.room_members) self.state_handler.get_current_user_in_room = get_current_user_in_room + self.datastore.get_user_directory_stream_pos.return_value = ( + # we deliberately return a non-None stream pos to avoid doing an initial_spam + defer.succeed(1) + ) + + self.datastore.get_current_state_deltas.return_value = ( + None + ) + self.auth.check_joined_room = check_joined_room self.datastore.get_to_device_stream_token = lambda: 0 diff --git a/tests/http/__init__.py b/tests/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/http/__init__.py diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py new file mode 100644 index 0000000000..60e6a75953 --- /dev/null +++ b/tests/http/test_endpoint.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name + +from tests import unittest + + +class ServerNameTestCase(unittest.TestCase): + def test_parse_server_name(self): + test_data = { + 'localhost': ('localhost', None), + 'my-example.com:1234': ('my-example.com', 1234), + '1.2.3.4': ('1.2.3.4', None), + '[0abc:1def::1234]': ('[0abc:1def::1234]', None), + '1.2.3.4:1': ('1.2.3.4', 1), + '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + } + + for i, o in test_data.items(): + self.assertEqual(parse_server_name(i), o) + + def test_validate_bad_server_names(self): + test_data = [ + "", # empty + "localhost:http", # non-numeric port + "1234]", # smells like ipv6 literal but isn't + "[1234", + "underscore_.com", + "percent%65.com", + "1234:5678:80", # too many colons + ] + for i in test_data: + try: + parse_and_validate_server_name(i) + self.fail( + "Expected parse_and_validate_server_name('%s') to throw" % ( + i, + ), + ) + except ValueError: + pass diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py deleted file mode 100644 index f85455a5af..0000000000 --- a/tests/metrics/test_metric.py +++ /dev/null @@ -1,161 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tests import unittest - -from synapse.metrics.metric import ( - CounterMetric, CallbackMetric, DistributionMetric, CacheMetric -) - - -class CounterMetricTestCase(unittest.TestCase): - - def test_scalar(self): - counter = CounterMetric("scalar") - - self.assertEquals(counter.render(), [ - 'scalar 0', - ]) - - counter.inc() - - self.assertEquals(counter.render(), [ - 'scalar 1', - ]) - - counter.inc_by(2) - - self.assertEquals(counter.render(), [ - 'scalar 3' - ]) - - def test_vector(self): - counter = CounterMetric("vector", labels=["method"]) - - # Empty counter doesn't yet know what values it has - self.assertEquals(counter.render(), []) - - counter.inc("GET") - - self.assertEquals(counter.render(), [ - 'vector{method="GET"} 1', - ]) - - counter.inc("GET") - counter.inc("PUT") - - self.assertEquals(counter.render(), [ - 'vector{method="GET"} 2', - 'vector{method="PUT"} 1', - ]) - - -class CallbackMetricTestCase(unittest.TestCase): - - def test_scalar(self): - d = dict() - - metric = CallbackMetric("size", lambda: len(d)) - - self.assertEquals(metric.render(), [ - 'size 0', - ]) - - d["key"] = "value" - - self.assertEquals(metric.render(), [ - 'size 1', - ]) - - def test_vector(self): - vals = dict() - - metric = CallbackMetric("values", lambda: vals, labels=["type"]) - - self.assertEquals(metric.render(), []) - - # Keys have to be tuples, even if they're 1-element - vals[("foo",)] = 1 - vals[("bar",)] = 2 - - self.assertEquals(metric.render(), [ - 'values{type="bar"} 2', - 'values{type="foo"} 1', - ]) - - -class DistributionMetricTestCase(unittest.TestCase): - - def test_scalar(self): - metric = DistributionMetric("thing") - - self.assertEquals(metric.render(), [ - 'thing:count 0', - 'thing:total 0', - ]) - - metric.inc_by(500) - - self.assertEquals(metric.render(), [ - 'thing:count 1', - 'thing:total 500', - ]) - - def test_vector(self): - metric = DistributionMetric("queries", labels=["verb"]) - - self.assertEquals(metric.render(), []) - - metric.inc_by(300, "SELECT") - metric.inc_by(200, "SELECT") - metric.inc_by(800, "INSERT") - - self.assertEquals(metric.render(), [ - 'queries:count{verb="INSERT"} 1', - 'queries:count{verb="SELECT"} 2', - 'queries:total{verb="INSERT"} 800', - 'queries:total{verb="SELECT"} 500', - ]) - - -class CacheMetricTestCase(unittest.TestCase): - - def test_cache(self): - d = dict() - - metric = CacheMetric("cache", lambda: len(d), "cache_name") - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 0', - 'cache:total{name="cache_name"} 0', - 'cache:size{name="cache_name"} 0', - ]) - - metric.inc_misses() - d["key"] = "value" - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 0', - 'cache:total{name="cache_name"} 1', - 'cache:size{name="cache_name"} 1', - ]) - - metric.inc_hits() - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 1', - 'cache:total{name="cache_name"} 2', - 'cache:size{name="cache_name"} 1', - ]) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 81063f19a1..8708c8a196 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor -from tests import unittest +import tempfile from mock import Mock, NonCallableMock -from tests.utils import setup_test_homeserver -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from twisted.internet import defer, reactor + from synapse.replication.tcp.client import ( - ReplicationClientHandler, ReplicationClientFactory, + ReplicationClientFactory, + ReplicationClientHandler, ) +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests import unittest +from tests.utils import setup_test_homeserver class BaseSlavedStoreTestCase(unittest.TestCase): @@ -29,7 +34,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.hs = yield setup_test_homeserver( "blue", http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), @@ -41,7 +46,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - listener = reactor.listenUNIX("\0xxx", server_factory) + # XXX: mktemp is unsafe and should never be used. but we're just a test. + path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") + listener = reactor.listenUNIX(path, server_factory) self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer @@ -49,7 +56,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) - client_connector = reactor.connectUNIX("\0xxx", client_factory) + client_connector = reactor.connectUNIX(path, client_factory) self.addCleanup(client_factory.stopTrying) self.addCleanup(client_connector.disconnect) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index da54d478ce..adf226404e 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -13,11 +13,11 @@ # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from twisted.internet import defer +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" TYPE = "my.type" @@ -37,10 +37,6 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.check( - "get_global_account_data_by_type_for_users", - [TYPE, [USER_ID]], {USER_ID: {"a": 1}} - ) yield self.master_store.add_account_data_for_user( USER_ID, TYPE, {"a": 2} @@ -50,7 +46,3 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) - yield self.check( - "get_global_account_data_by_type_for_users", - [TYPE, [USER_ID]], {USER_ID: {"a": 2}} - ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 105e1228bb..f5b47f5ec0 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser -from twisted.internet import defer - +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" USER_ID_2 = "@bright:blue" @@ -223,16 +222,21 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_ids = { key: e.event_id for key, e in state.items() } - context = EventContext() - context.current_state_ids = state_ids - context.prev_state_ids = state_ids - elif not backfill: + context = EventContext.with_state( + state_group=None, + current_state_ids=state_ids, + prev_state_ids=state_ids + ) + else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) - else: - context = EventContext() - context.push_actions = push_actions + yield self.master_store.add_push_actions_to_staging( + event.event_id, { + user_id: actions + for user_id, actions in push_actions + }, + ) ordering = None if backfill: diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 6624fe4eea..e6d670cc1f 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from twisted.internet import defer +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" ROOM_ID = "!room:blue" diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index d7cea30260..34e68ae82f 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,7 +1,11 @@ -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.rest.client.transactions import CLEANUP_PERIOD_MS -from twisted.internet import defer from mock import Mock, call + +from twisted.internet import defer, reactor + +from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache +from synapse.util import Clock +from synapse.util.logcontext import LoggingContext + from tests import unittest from tests.utils import MockClock @@ -10,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self): self.clock = MockClock() - self.cache = HttpTransactionCache(self.clock) + self.hs = Mock() + self.hs.get_clock = Mock(return_value=self.clock) + self.hs.get_auth = Mock() + self.cache = HttpTransactionCache(self.hs) self.mock_http_response = (200, "GOOD JOB!") self.mock_key = "foo" @@ -40,6 +47,78 @@ class HttpTransactionCacheTestCase(unittest.TestCase): cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0) @defer.inlineCallbacks + def test_logcontexts_with_async_result(self): + @defer.inlineCallbacks + def cb(): + yield Clock(reactor).sleep(0) + defer.returnValue("yay") + + @defer.inlineCallbacks + def test(): + with LoggingContext("c") as c1: + res = yield self.cache.fetch_or_execute(self.mock_key, cb) + self.assertIs(LoggingContext.current_context(), c1) + self.assertEqual(res, "yay") + + # run the test twice in parallel + d = defer.gatherResults([test(), test()]) + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + yield d + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + + @defer.inlineCallbacks + def test_does_not_cache_exceptions(self): + """Checks that, if the callback throws an exception, it is called again + for the next request. + """ + called = [False] + + def cb(): + if called[0]: + # return a valid result the second time + return defer.succeed(self.mock_http_response) + + called[0] = True + raise Exception("boo") + + with LoggingContext("test") as test_context: + try: + yield self.cache.fetch_or_execute(self.mock_key, cb) + except Exception as e: + self.assertEqual(e.message, "boo") + self.assertIs(LoggingContext.current_context(), test_context) + + res = yield self.cache.fetch_or_execute(self.mock_key, cb) + self.assertEqual(res, self.mock_http_response) + self.assertIs(LoggingContext.current_context(), test_context) + + @defer.inlineCallbacks + def test_does_not_cache_failures(self): + """Checks that, if the callback returns a failure, it is called again + for the next request. + """ + called = [False] + + def cb(): + if called[0]: + # return a valid result the second time + return defer.succeed(self.mock_http_response) + + called[0] = True + return defer.fail(Exception("boo")) + + with LoggingContext("test") as test_context: + try: + yield self.cache.fetch_or_execute(self.mock_key, cb) + except Exception as e: + self.assertEqual(e.message, "boo") + self.assertIs(LoggingContext.current_context(), test_context) + + res = yield self.cache.fetch_or_execute(self.mock_key, cb) + self.assertEqual(res, self.mock_http_response) + self.assertIs(LoggingContext.current_context(), test_context) + + @defer.inlineCallbacks def test_cleans_up(self): cb = Mock( return_value=defer.succeed(self.mock_http_response) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py new file mode 100644 index 0000000000..8c90145601 --- /dev/null +++ b/tests/rest/client/v1/test_admin.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import hmac +import json + +from mock import Mock + +from synapse.http.server import JsonResource +from synapse.rest.client.v1.admin import register_servlets +from synapse.util import Clock + +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) + + +class UserRegisterTestCase(unittest.TestCase): + def setUp(self): + + self.clock = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = "/_matrix/client/r0/admin/register" + + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=[]) + + self.secrets = Mock() + + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock + ) + + self.hs.config.registration_shared_secret = u"shared" + + self.hs.get_media_repository = Mock() + self.hs.get_deactivate_account_handler = Mock() + + self.resource = JsonResource(self.hs) + register_servlets(self.hs, self.resource) + + def test_disabled(self): + """ + If there is no shared secret, registration through this method will be + prevented. + """ + self.hs.config.registration_shared_secret = None + + request, channel = make_request("POST", self.url, b'{}') + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + 'Shared secret registration is not enabled', channel.json_body["error"] + ) + + def test_get_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, using the + homeserver's secrets provider. + """ + secrets = Mock() + secrets.token_hex = Mock(return_value="abcd") + + self.hs.get_secrets = Mock(return_value=secrets) + + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + + self.assertEqual(channel.json_body, {"nonce": "abcd"}) + + def test_expired_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, which will + only last for SALT_TIMEOUT (60s). + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + # 59 seconds + self.clock.advance(59) + + body = json.dumps({"nonce": nonce}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('username must be specified', channel.json_body["error"]) + + # 61 seconds + self.clock.advance(2) + + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('unrecognised nonce', channel.json_body["error"]) + + def test_register_incorrect_nonce(self): + """ + Only the provided nonce can be used, as it's checked in the MAC. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("HMAC incorrect", channel.json_body["error"]) + + def test_register_correct_nonce(self): + """ + When the correct nonce is provided, and the right key is provided, the + user is registered. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + def test_nonce_reuse(self): + """ + A valid unrecognised nonce. + """ + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ).encode('utf8') + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + # Now, try and reuse it + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('unrecognised nonce', channel.json_body["error"]) + + def test_missing_parts(self): + """ + Synapse will complain if you don't give nonce, username, password, and + mac. Admin is optional. Additional checks are done for length and + type. + """ + def nonce(): + request, channel = make_request("GET", self.url) + render(request, self.resource, self.clock) + return channel.json_body["nonce"] + + # + # Nonce check + # + + # Must be present + body = json.dumps({}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('nonce must be specified', channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce()}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('username must be specified', channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": 1234}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid username', channel.json_body["error"]) + + # + # Username checks + # + + # Must be present + body = json.dumps({"nonce": nonce(), "username": "a"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('password must be specified', channel.json_body["error"]) + + # Must be a string + body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) + + # Must not have null bytes + body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) + + # Super long + body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) + request, channel = make_request("POST", self.url, body.encode('utf8')) + render(request, self.resource, self.clock) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual('Invalid password', channel.json_body["error"]) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index e9698bfdc9..50418153fa 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -14,107 +14,35 @@ # limitations under the License. """ Tests REST events for /events paths.""" -from tests import unittest -# twisted imports -from twisted.internet import defer - -import synapse.rest.client.v1.events -import synapse.rest.client.v1.register -import synapse.rest.client.v1.room +from mock import Mock, NonCallableMock +from six import PY3 +from twisted.internet import defer from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - - PATH_PREFIX = "/_matrix/client/api/v1" -class EventStreamPaginationApiTestCase(unittest.TestCase): - """ Tests event streaming query parameters and start/end keys used in the - Pagination stream API. """ - user_id = "sid1" - - def setUp(self): - # configure stream and inject items - pass - - def tearDown(self): - pass - - def TODO_test_long_poll(self): - # stream from 'end' key, send (self+other) message, expect message. - - # stream from 'END', send (self+other) message, expect message. - - # stream from 'end' key, send (self+other) topic, expect topic. - - # stream from 'END', send (self+other) topic, expect topic. - - # stream from 'end' key, send (self+other) invite, expect invite. - - # stream from 'END', send (self+other) invite, expect invite. - - pass - - def TODO_test_stream_forward(self): - # stream from START, expect injected items - - # stream from 'start' key, expect same content - - # stream from 'end' key, expect nothing - - # stream from 'END', expect nothing - - # The following is needed for cases where content is removed e.g. you - # left a room, so the token you're streaming from is > the one that - # would be returned naturally from START>END. - # stream from very new token (higher than end key), expect same token - # returned as end key - pass - - def TODO_test_limits(self): - # stream from a key, expect limit_num items - - # stream from START, expect limit_num items - - pass - - def TODO_test_range(self): - # stream from key to key, expect X items - - # stream from key to END, expect X items - - # stream from START to key, expect X items - - # stream from START to END, expect all items - pass - - def TODO_test_direction(self): - # stream from END to START and fwds, expect newest first - - # stream from END to START and bwds, expect oldest first - - # stream from START to END and fwds, expect oldest first - - # stream from START to END and bwds, expect newest first - - pass - - class EventStreamPermissionsTestCase(RestTestCase): """ Tests event streaming (GET /events). """ + if PY3: + skip = "Skip on Py3 until ported to use not V1 only register." + @defer.inlineCallbacks def setUp(self): + import synapse.rest.client.v1.events + import synapse.rest.client.v1_only.register + import synapse.rest.client.v1.room + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) hs = yield setup_test_homeserver( http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), @@ -123,10 +51,11 @@ class EventStreamPermissionsTestCase(RestTestCase): self.ratelimiter.send_message.return_value = (True, 0) hs.config.enable_registration_captcha = False hs.config.enable_registration = True + hs.config.auto_join_rooms = [] hs.get_handlers().federation_handler = Mock() - synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource) + synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) @@ -147,11 +76,16 @@ class EventStreamPermissionsTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_basic_permissions(self): - # invalid token, expect 403 + # invalid token, expect 401 + # note: this is in violation of the original v1 spec, which expected + # 403. However, since the v1 spec no longer exists and the v1 + # implementation is now part of the r0 implementation, the newer + # behaviour is used instead to be consistent with the r0 spec. + # see issue #2602 (code, response) = yield self.mock_resource.trigger_get( "/events?access_token=%s" % ("invalid" + self.token, ) ) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(401, code, msg=str(response)) # valid token, expect content (code, response) = yield self.mock_resource.trigger_get( diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 1e95e97538..d71cc8e0db 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -15,12 +15,15 @@ """Tests REST events for /profile paths.""" from mock import Mock + from twisted.internet import defer import synapse.types -from synapse.api.errors import SynapseError, AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.rest.client.v1 import profile + from tests import unittest + from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" @@ -45,15 +48,14 @@ class ProfileTestCase(unittest.TestCase): http_client=None, resource_for_client=self.mock_resource, federation=Mock(), - replication_layer=Mock(), + federation_client=Mock(), + profile_handler=self.mock_handler ) def _get_user_by_req(request=None, allow_guest=False): return synapse.types.create_requester(myid) - hs.get_v1auth().get_user_by_req = _get_user_by_req - - hs.get_handlers().profile_handler = self.mock_handler + hs.get_auth().get_user_by_req = _get_user_by_req profile.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index a6a4e2ffe0..83a23cd8fe 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -13,26 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.client.v1.register import CreateUserRestServlet -from twisted.internet import defer +import json + from mock import Mock +from six import PY3 + +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.http.server import JsonResource +from synapse.rest.client.v1_only.register import register_servlets +from synapse.util import Clock + from tests import unittest -from tests.utils import mock_getRawHeaders -import json +from tests.server import make_request, setup_test_homeserver class CreateUserServletTestCase(unittest.TestCase): + """ + Tests for CreateUserRestServlet. + """ + if PY3: + skip = "Not ported to Python 3." def setUp(self): - # do the dance to hook up request data to self.request_data - self.request_data = "" - self.request = Mock( - content=Mock(read=Mock(side_effect=lambda: self.request_data)), - path='/_matrix/client/api/v1/createUser' - ) - self.request.args = {} - self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() - self.registration_handler = Mock() self.appservice = Mock(sender="@as:test") @@ -40,39 +43,49 @@ class CreateUserServletTestCase(unittest.TestCase): get_app_service_by_token=Mock(return_value=self.appservice) ) - # do the dance to hook things up to the hs global - handlers = Mock( - registration_handler=self.registration_handler, + handlers = Mock(registration_handler=self.registration_handler) + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) - self.hs = Mock() - self.hs.hostname = "superbig~testing~thing.com" self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_handlers = Mock(return_value=handlers) - self.servlet = CreateUserRestServlet(self.hs) - @defer.inlineCallbacks def test_POST_createuser_with_valid_user(self): + + res = JsonResource(self.hs) + register_servlets(self.hs, res) + + request_data = json.dumps( + { + "localpart": "someone", + "displayname": "someone interesting", + "duration_seconds": 200, + } + ) + + url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service' + user_id = "@someone:interesting" token = "my token" - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "localpart": "someone", - "displayname": "someone interesting", - "duration_seconds": 200 - }) self.registration_handler.get_or_create_user = Mock( return_value=(user_id, token) ) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) + request, channel = make_request(b"POST", url, request_data) + request.render(res) + + # Advance the clock because it waits + self.clock.advance(1) + + self.assertEquals(channel.result["code"], b"200") det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d746ea8568..00fc796787 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -15,963 +15,782 @@ """Tests REST events for /rooms paths.""" -# twisted imports +import json + +from mock import Mock, NonCallableMock +from six.moves.urllib import parse as urlparse + from twisted.internet import defer import synapse.rest.client.v1.room from synapse.api.constants import Membership - +from synapse.http.server import JsonResource from synapse.types import UserID +from synapse.util import Clock -import json -import urllib - -from ....utils import MockHttpResource, setup_test_homeserver -from .utils import RestTestCase +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) -from mock import Mock, NonCallableMock +from .utils import RestHelper -PATH_PREFIX = "/_matrix/client/api/v1" +PATH_PREFIX = b"/_matrix/client/api/v1" -class RoomPermissionsTestCase(RestTestCase): - """ Tests room permissions. """ - user_id = "@sid1:red" - rmcreator_id = "@notme:red" +class RoomBase(unittest.TestCase): + rmcreator_id = None - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - hs = yield setup_test_homeserver( + self.clock = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = setup_test_homeserver( "red", http_client=None, - replication_layer=Mock(), + clock=self.hs_clock, + reactor=self.clock, + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) - self.ratelimiter = hs.get_ratelimiter() + self.ratelimiter = self.hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() + self.hs.get_federation_handler = Mock(return_value=Mock()) def get_user_by_access_token(token=None, allow_guest=False): return { - "user": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.helper.auth_user_id), "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None + ) + + self.hs.get_auth().get_user_by_req = get_user_by_req + self.hs.get_auth().get_user_by_access_token = get_user_by_access_token + self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234") def _insert_client_ip(*args, **kwargs): return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - self.auth_user_id = self.rmcreator_id + self.hs.get_datastore().insert_client_ip = _insert_client_ip - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + self.resource = JsonResource(self.hs) + synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) + synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource) + self.helper = RestHelper(self.hs, self.resource, self.user_id) - self.auth = hs.get_v1auth() - # create some rooms under the name rmcreator_id - self.uncreated_rmid = "!aa:test" +class RoomPermissionsTestCase(RoomBase): + """ Tests room permissions. """ + + user_id = b"@sid1:red" + rmcreator_id = b"@notme:red" + + def setUp(self): - self.created_rmid = yield self.create_room_as(self.rmcreator_id, - is_public=False) + super(RoomPermissionsTestCase, self).setUp() - self.created_public_rmid = yield self.create_room_as(self.rmcreator_id, - is_public=True) + self.helper.auth_user_id = self.rmcreator_id + # create some rooms under the name rmcreator_id + self.uncreated_rmid = "!aa:test" + self.created_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=False + ) + self.created_public_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=True + ) # send a message in one of the rooms self.created_rmid_msg_path = ( - "/rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ) - (code, response) = yield self.mock_resource.trigger( - "PUT", + "rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ).encode('ascii') + request, channel = make_request( + b"PUT", self.created_rmid_msg_path, - '{"msgtype":"m.text","body":"test msg"}' + b'{"msgtype":"m.text","body":"test msg"}', ) - self.assertEquals(200, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(channel.result["code"], b"200", channel.result) # set topic for public room - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - '{"topic":"Public Room Topic"}' + request, channel = make_request( + b"PUT", + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), + b'{"topic":"Public Room Topic"}', ) - self.assertEquals(200, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(channel.result["code"], b"200", channel.result) # auth as user_id now - self.auth_user_id = self.user_id - - def tearDown(self): - pass + self.helper.auth_user_id = self.user_id - @defer.inlineCallbacks def test_send_message(self): - msg_content = '{"msgtype":"m.text","body":"hello"}' - send_msg_path = ( - "/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,) - ) + msg_content = b'{"msgtype":"m.text","body":"hello"}' + + seq = iter(range(100)) + + def send_msg_path(): + return b"/rooms/%s/send/m.room.message/mid%s" % ( + self.created_rmid, + str(next(seq)).encode('ascii'), + ) # send message in uncreated room, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), - msg_content + request, channel = make_request( + b"PUT", + b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content, ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and invited, expect 403 - yield self.invite( - room=self.created_rmid, - src=self.rmcreator_id, - targ=self.user_id - ) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and joined, expect 200 - yield self.join(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(200, code, msg=str(response)) + self.helper.join(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # send message in created room and left, expect 403 - yield self.leave(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", - send_msg_path, - msg_content - ) - self.assertEquals(403, code, msg=str(response)) + self.helper.leave(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", send_msg_path(), msg_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_topic_perms(self): - topic_content = '{"topic":"My Topic Name"}' - topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid + topic_content = b'{"topic":"My Topic Name"}' + topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, - topic_content + request, channel = make_request( + b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request( + b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 - yield self.invite( + self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 - yield self.join(room=self.created_rmid, user=self.user_id) + self.helper.join(room=self.created_rmid, user=self.user_id) # Only room ops can set topic by default - self.auth_user_id = self.rmcreator_id - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = self.user_id + self.helper.auth_user_id = self.rmcreator_id + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.helper.auth_user_id = self.user_id - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(topic_content), response) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 - yield self.leave(room=self.created_rmid, user=self.user_id) - (code, response) = yield self.mock_resource.trigger( - "PUT", topic_path, topic_content - ) - self.assertEquals(403, code, msg=str(response)) - (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(200, code, msg=str(response)) + self.helper.leave(room=self.created_rmid, user=self.user_id) + request, channel = make_request(b"PUT", topic_path, topic_content) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = make_request(b"GET", topic_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/state/m.room.topic" % self.created_public_rmid + request, channel = make_request( + b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content + request, channel = make_request( + b"PUT", + b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content, ) - self.assertEquals(403, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: - path = "/rooms/%s/state/m.room.member/%s" % (room, member) - (code, response) = yield self.mock_resource.trigger_get(path) - self.assertEquals(expect_code, code) + path = b"/rooms/%s/state/m.room.member/%s" % (room, member) + request, channel = make_request(b"GET", path) + render(request, self.resource, self.clock) + self.assertEquals(expect_code, int(channel.result["code"])) - @defer.inlineCallbacks def test_membership_basic_room_perms(self): # === room does not exist === room = self.uncreated_rmid # get membership of self, get membership of other, uncreated room # expect all 403s - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # trying to invite people to this room should 403 - yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id, - expect_code=403) + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: - yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=404) + self.helper.join(room=room, user=usr, expect_code=404) + self.helper.leave(room=room, user=usr, expect_code=404) - @defer.inlineCallbacks def test_membership_private_room_perms(self): room = self.created_rmid # get membership of self, get membership of other, private room + invite # expect all 403s - yield self.invite(room=room, src=self.rmcreator_id, - targ=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # get membership of self, get membership of other, private room + joined # expect all 200s - yield self.join(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) # get membership of self, get membership of other, private room + left # expect all 200s - yield self.leave(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) - @defer.inlineCallbacks def test_membership_public_room_perms(self): room = self.created_public_rmid # get membership of self, get membership of other, public room + invite # expect 403 - yield self.invite(room=room, src=self.rmcreator_id, - targ=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) # get membership of self, get membership of other, public room + joined # expect all 200s - yield self.join(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) # get membership of self, get membership of other, public room + left # expect 200. - yield self.leave(room=room, user=self.user_id) - yield self._test_get_membership( - members=[self.user_id, self.rmcreator_id], - room=room, expect_code=200) + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) - @defer.inlineCallbacks def test_invited_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) # set [invite/join/left] of other user, expect 403s - yield self.invite(room=room, src=self.user_id, targ=self.rmcreator_id, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.JOIN, - expect_code=403) - yield self.change_membership(room=room, src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403) - - @defer.inlineCallbacks + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.JOIN, + expect_code=403, + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + def test_joined_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - yield self.join(room=room, user=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) # set invited of self, expect 403 - yield self.invite(room=room, src=self.user_id, targ=self.user_id, - expect_code=403) + self.helper.invite( + room=room, src=self.user_id, targ=self.user_id, expect_code=403 + ) # set joined of self, expect 200 (NOOP) - yield self.join(room=room, user=self.user_id) + self.helper.join(room=room, user=self.user_id) other = "@burgundy:red" # set invited of other, expect 200 - yield self.invite(room=room, src=self.user_id, targ=other, - expect_code=200) + self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) # set joined of other, expect 403 - yield self.change_membership(room=room, src=self.user_id, - targ=other, - membership=Membership.JOIN, - expect_code=403) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.JOIN, + expect_code=403, + ) # set left of other, expect 403 - yield self.change_membership(room=room, src=self.user_id, - targ=other, - membership=Membership.LEAVE, - expect_code=403) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, + ) # set left of self, expect 200 - yield self.leave(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) - @defer.inlineCallbacks def test_leave_permissions(self): room = self.created_rmid - yield self.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - yield self.join(room=room, user=self.user_id) - yield self.leave(room=room, user=self.user_id) + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) # set [invite/join/left] of self, set [invite/join/left] of other, # expect all 403s for usr in [self.user_id, self.rmcreator_id]: - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403 + expect_code=403, ) - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403 + expect_code=403, ) # It is always valid to LEAVE if you've already left (currently.) - yield self.change_membership( + self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403 + expect_code=403, ) -class RoomsMemberListTestCase(RestTestCase): +class RoomsMemberListTestCase(RoomBase): """ Tests /rooms/$room_id/members/list REST events.""" - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - self.auth_user_id = self.user_id - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - def tearDown(self): - pass + user_id = b"@sid1:red" - @defer.inlineCallbacks def test_get_member_list(self): - room_id = yield self.create_room_as(self.user_id) - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id - ) - self.assertEquals(200, code, msg=str(response)) + room_id = self.helper.create_room_as(self.user_id) + request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_no_room(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/roomdoesnotexist/members" - ) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_no_permission(self): - room_id = yield self.create_room_as("@some_other_guy:red") - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/members" % room_id - ) - self.assertEquals(403, code, msg=str(response)) + room_id = self.helper.create_room_as(b"@some_other_guy:red") + request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_get_member_list_mixed_memberships(self): - room_creator = "@some_other_guy:red" - room_id = yield self.create_room_as(room_creator) - room_path = "/rooms/%s/members" % room_id - yield self.invite(room=room_id, src=room_creator, - targ=self.user_id) + room_creator = b"@some_other_guy:red" + room_id = self.helper.create_room_as(room_creator) + room_path = b"/rooms/%s/members" % room_id + self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(403, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - yield self.join(room=room_id, user=self.user_id) + self.helper.join(room=room_id, user=self.user_id) # can see list now joined - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - yield self.leave(room=room_id, user=self.user_id) + self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", room_path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) -class RoomsCreateTestCase(RestTestCase): +class RoomsCreateTestCase(RoomBase): """ Tests /rooms and /rooms/$room_id REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + user_id = b"@sid1:red" - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - def tearDown(self): - pass - - @defer.inlineCallbacks def test_post_room_no_keys(self): # POST with no config keys, expect new room id - (code, response) = yield self.mock_resource.trigger("POST", - "/createRoom", - "{}") - self.assertEquals(200, code, response) - self.assertTrue("room_id" in response) + request, channel = make_request(b"POST", b"/createRoom", b"{}") + + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), channel.result) + self.assertTrue("room_id" in channel.json_body) - @defer.inlineCallbacks def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibility":"private"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"POST", b"/createRoom", b'{"visibility":"private"}' + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"custom":"stuff"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibility":"private","custom":"things"}') - self.assertEquals(200, code) - self.assertTrue("room_id" in response) - - @defer.inlineCallbacks + request, channel = make_request( + b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("room_id" in channel.json_body) + def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '{"visibili') - self.assertEquals(400, code) + request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"])) - (code, response) = yield self.mock_resource.trigger( - "POST", - "/createRoom", - '["hello"]') - self.assertEquals(400, code) + request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"])) -class RoomTopicTestCase(RestTestCase): +class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } + user_id = b"@sid1:red" - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + def setUp(self): - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + super(RoomTopicTestCase, self).setUp() # create the room - self.room_id = yield self.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) + self.room_id = self.helper.create_room_as(self.user_id) + self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) - def tearDown(self): - pass - - @defer.inlineCallbacks def test_invalid_puts(self): # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_topic(self): # nothing should be there - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(404, code, msg=str(response)) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # valid get - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(content), response) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) - @defer.inlineCallbacks def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - (code, response) = yield self.mock_resource.trigger( - "PUT", self.path, content - ) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", self.path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # valid get - (code, response) = yield self.mock_resource.trigger_get(self.path) - self.assertEquals(200, code, msg=str(response)) - self.assert_dict(json.loads(content), response) + request, channel = make_request(b"GET", self.path) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) -class RoomMemberStateTestCase(RestTestCase): +class RoomMemberStateTestCase(RoomBase): """ Tests /rooms/$room_id/members/$user_id/state REST events. """ - user_id = "@sid1:red" - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip + user_id = b"@sid1:red" - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + def setUp(self): - self.room_id = yield self.create_room_as(self.user_id) + super(RoomMemberStateTestCase, self).setUp() + self.room_id = self.helper.create_room_as(self.user_id) def tearDown(self): pass - @defer.inlineCallbacks def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger("PUT", path, '{}') - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # valid keys, wrong types - content = ('{"membership":["%s","%s","%s"]}' % ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE - )) - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(400, code, msg=str(response)) + content = '{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + ) + request, channel = make_request(b"PUT", path, content.encode('ascii')) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.user_id + urlparse.quote(self.room_id), + self.user_id, ) # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content.encode('ascii')) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - expected_response = { - "membership": Membership.JOIN, - } - self.assertEquals(expected_response, response) + expected_response = {"membership": Membership.JOIN} + self.assertEquals(expected_response, channel.json_body) - @defer.inlineCallbacks def test_rooms_members_other(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), + self.other_id, ) # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) - self.assertEquals(json.loads(content), response) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) - @defer.inlineCallbacks def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), + self.other_id, ) # valid invite message with custom key - content = ('{"membership":"%s","invite_text":"%s"}' % ( - Membership.INVITE, "Join us!" - )) - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + content = '{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, + "Join us!", + ) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger("GET", path, None) - self.assertEquals(200, code, msg=str(response)) - self.assertEquals(json.loads(content), response) + request, channel = make_request(b"GET", path, None) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) -class RoomMessagesTestCase(RestTestCase): +class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id + super(RoomMessagesTestCase, self).setUp() - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - self.room_id = yield self.create_room_as(self.user_id) - - def tearDown(self): - pass + self.room_id = self.helper.create_room_as(self.user_id) - @defer.inlineCallbacks def test_invalid_puts(self): - path = "/rooms/%s/send/m.room.message/mid1" % ( - urllib.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"_name":"bob"}' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"_name":"bob"}') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '{"nao' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '{"nao') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = make_request( + b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' ) - self.assertEquals(400, code, msg=str(response)) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, 'text only' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, 'text only') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - (code, response) = yield self.mock_resource.trigger( - "PUT", path, '' - ) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, '') + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) - @defer.inlineCallbacks def test_rooms_messages_sent(self): - path = "/rooms/%s/send/m.room.message/mid1" % ( - urllib.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = '{"body":"test","msgtype":{"type":"a"}}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(400, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) # custom message types content = '{"body":"test","msgtype":"test.custom.text"}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) - -# (code, response) = yield self.mock_resource.trigger("GET", path, None) -# self.assertEquals(200, code, msg=str(response)) -# self.assert_dict(json.loads(content), response) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) # m.text message type - path = "/rooms/%s/send/m.room.message/mid2" % ( - urllib.quote(self.room_id)) + path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = '{"body":"test2","msgtype":"m.text"}' - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(200, code, msg=str(response)) + request, channel = make_request(b"PUT", path, content) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) -class RoomInitialSyncTestCase(RestTestCase): +class RoomInitialSyncTestCase(RoomBase): """ Tests /rooms/$room_id/initialSync. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id - - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + super(RoomInitialSyncTestCase, self).setUp() # create the room - self.room_id = yield self.create_room_as(self.user_id) + self.room_id = self.helper.create_room_as(self.user_id) - @defer.inlineCallbacks def test_initial_sync(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/initialSync" % self.room_id - ) - self.assertEquals(200, code) + request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) - self.assertEquals(self.room_id, response["room_id"]) - self.assertEquals("join", response["membership"]) + self.assertEquals(self.room_id, channel.json_body["room_id"]) + self.assertEquals("join", channel.json_body["membership"]) # Room state is easier to assert on if we unpack it into a dict state = {} - for event in response["state"]: + for event in channel.json_body["state"]: if "state_key" not in event: continue t = event["type"] @@ -981,75 +800,48 @@ class RoomInitialSyncTestCase(RestTestCase): self.assertTrue("m.room.create" in state) - self.assertTrue("messages" in response) - self.assertTrue("chunk" in response["messages"]) - self.assertTrue("end" in response["messages"]) + self.assertTrue("messages" in channel.json_body) + self.assertTrue("chunk" in channel.json_body["messages"]) + self.assertTrue("end" in channel.json_body["messages"]) - self.assertTrue("presence" in response) + self.assertTrue("presence" in channel.json_body) presence_by_user = { - e["content"]["user_id"]: e for e in response["presence"] + e["content"]["user_id"]: e for e in channel.json_body["presence"] } self.assertTrue(self.user_id in presence_by_user) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) -class RoomMessageListTestCase(RestTestCase): +class RoomMessageListTestCase(RoomBase): """ Tests /rooms/$room_id/messages REST events. """ + user_id = "@sid1:red" - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.auth_user_id = self.user_id + super(RoomMessageListTestCase, self).setUp() + self.room_id = self.helper.create_room_as(self.user_id) - hs = yield setup_test_homeserver( - "red", - http_client=None, - replication_layer=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + def test_topo_token_is_accepted(self): + token = "t1-0_0_0_0_0_0_0_0_0" + request, channel = make_request( + b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body['start']) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) - hs.get_handlers().federation_handler = Mock() - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token - - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) - hs.get_datastore().insert_client_ip = _insert_client_ip - - synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - - self.room_id = yield self.create_room_as(self.user_id) - - @defer.inlineCallbacks - def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0_0_0" - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=%s" % - (self.room_id, token)) - self.assertEquals(200, code) - self.assertTrue("start" in response) - self.assertEquals(token, response['start']) - self.assertTrue("chunk" in response) - self.assertTrue("end" in response) - - @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0_0_0" - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=%s" % - (self.room_id, token)) - self.assertEquals(200, code) - self.assertTrue("start" in response) - self.assertEquals(token, response['start']) - self.assertTrue("chunk" in response) - self.assertTrue("end" in response) + token = "s0_0_0_0_0_0_0_0_0" + request, channel = make_request( + b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + ) + render(request, self.resource, self.clock) + self.assertEquals(200, int(channel.result["code"])) + self.assertTrue("start" in channel.json_body) + self.assertEquals(token, channel.json_body['start']) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("end" in channel.json_body) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index a269e6f56e..bddb3302e4 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -15,18 +15,17 @@ """Tests REST events for /rooms paths.""" +from mock import Mock, NonCallableMock + # twisted imports from twisted.internet import defer import synapse.rest.client.v1.room from synapse.types import UserID -from ....utils import MockHttpResource, MockClock, setup_test_homeserver +from ....utils import MockClock, MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - - PATH_PREFIX = "/_matrix/client/api/v1" @@ -47,7 +46,7 @@ class RoomTypingTestCase(RestTestCase): "red", clock=self.clock, http_client=None, - replication_layer=Mock(), + federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=[ "send_message", ]), @@ -68,7 +67,7 @@ class RoomTypingTestCase(RestTestCase): "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -95,7 +94,7 @@ class RoomTypingTestCase(RestTestCase): else: if remotedomains is not None: remotedomains.add(member.domain) - hs.get_handlers().room_member_handler.fetch_room_distributions_into = ( + hs.get_room_member_handler().fetch_room_distributions_into = ( fetch_room_distributions_into ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 3bb1dd003a..41de8e0762 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -# twisted imports -from twisted.internet import defer +import json +import time -# trial imports -from tests import unittest +import attr + +from twisted.internet import defer from synapse.api.constants import Membership -import json -import time +from tests import unittest +from tests.server import make_request, wait_until_result class RestTestCase(unittest.TestCase): @@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase): for key in required: self.assertEquals(required[key], actual[key], msg="%s mismatch. %s" % (key, actual)) + + +@attr.s +class RestHelper(object): + """Contains extra helper functions to quickly and clearly perform a given + REST action, which isn't the focus of the test. + """ + + hs = attr.ib() + resource = attr.ib() + auth_user_id = attr.ib() + + def create_room_as(self, room_creator, is_public=True, tok=None): + temp_id = self.auth_user_id + self.auth_user_id = room_creator + path = b"/_matrix/client/r0/createRoom" + content = {} + if not is_public: + content["visibility"] = "private" + if tok: + path = path + b"?access_token=%s" % tok.encode('ascii') + + request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8')) + request.render(self.resource) + wait_until_result(self.hs.get_reactor(), channel) + + assert channel.result["code"] == b"200", channel.result + self.auth_user_id = temp_id + return channel.json_body["room_id"] + + def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) + + def join(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) + + def leave(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) + + def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): + temp_id = self.auth_user_id + self.auth_user_id = src + + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + if tok: + path = path + "?access_token=%s" % tok + + data = {"membership": membership} + + request, channel = make_request( + b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8') + ) + + request.render(self.resource) + wait_until_result(self.hs.get_reactor(), channel) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + self.auth_user_id = temp_id + + @defer.inlineCallbacks + def register(self, user_id): + (code, response) = yield self.mock_resource.trigger( + "POST", + "/_matrix/client/r0/register", + json.dumps( + {"user": user_id, "password": "test", "type": "m.login.password"} + ), + ) + self.assertEquals(200, code) + defer.returnValue(response) + + @defer.inlineCallbacks + def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + if body is None: + body = "body_text_here" + + path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) + content = '{"msgtype":"m.text","body":"%s"}' % body + if tok: + path = path + "?access_token=%s" % tok + + (code, response) = yield self.mock_resource.trigger("PUT", path, content) + self.assertEquals(expect_code, code, msg=str(response)) diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 5170217d9e..e69de29bb2 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tests import unittest - -from mock import Mock - -from ....utils import MockHttpResource, setup_test_homeserver - -from synapse.types import UserID - -from twisted.internet import defer - - -PATH_PREFIX = "/_matrix/client/v2_alpha" - - -class V2AlphaRestTestCase(unittest.TestCase): - # Consumer must define - # USER_ID = <some string> - # TO_REGISTER = [<list of REST servlets to register>] - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - - hs = yield setup_test_homeserver( - datastore=self.make_datastore_mock(), - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, - ) - - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.USER_ID), - "token_id": 1, - "is_guest": False, - } - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - for r in self.TO_REGISTER: - r.register_servlets(hs, self.mock_resource) - - def make_datastore_mock(self): - store = Mock(spec=[ - "insert_client_ip", - ]) - store.get_app_service_by_token = Mock(return_value=None) - return store diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 76b833e119..e890f0feac 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,38 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from tests import unittest - -from synapse.rest.client.v2_alpha import filter - -from synapse.api.errors import Codes - import synapse.types - +from synapse.api.errors import Codes +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha import filter from synapse.types import UserID +from synapse.util import Clock -from ....utils import MockHttpResource, setup_test_homeserver +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock as MemoryReactorClock, + make_request, + setup_test_homeserver, + wait_until_result, +) PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): - USER_ID = "@apple:test" + USER_ID = b"@apple:test" EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} - EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' + EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' TO_REGISTER = [filter] - @defer.inlineCallbacks def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) - self.hs = yield setup_test_homeserver( - http_client=None, - resource_for_client=self.mock_resource, - resource_for_federation=self.mock_resource, + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) self.auth = self.hs.get_auth() @@ -58,82 +57,103 @@ class FilterTestCase(unittest.TestCase): def get_user_by_req(request, allow_guest=False, rights="access"): return synapse.types.create_requester( - UserID.from_string(self.USER_ID), 1, False, None) + UserID.from_string(self.USER_ID), 1, False, None + ) self.auth.get_user_by_access_token = get_user_by_access_token self.auth.get_user_by_req = get_user_by_req self.store = self.hs.get_datastore() self.filtering = self.hs.get_filtering() + self.resource = JsonResource(self.hs) for r in self.TO_REGISTER: - r.register_servlets(self.hs, self.mock_resource) + r.register_servlets(self.hs, self.resource) - @defer.inlineCallbacks def test_add_filter(self): - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON - ) - self.assertEquals(200, code) - self.assertEquals({"filter_id": "0"}, response) - filter = yield self.store.get_user_filter( - user_localpart='apple', - filter_id=0, + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + self.EXAMPLE_FILTER_JSON, ) - self.assertEquals(filter, self.EXAMPLE_FILTER) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.json_body, {"filter_id": "0"}) + filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.clock.advance(0) + self.assertEquals(filter.result, self.EXAMPLE_FILTER) - @defer.inlineCallbacks def test_add_filter_for_other_user(self): - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"), + self.EXAMPLE_FILTER_JSON, ) - self.assertEquals(403, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - @defer.inlineCallbacks def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON + request, channel = make_request( + b"POST", + b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), + self.EXAMPLE_FILTER_JSON, ) + request.render(self.resource) + wait_until_result(self.clock, channel) + self.hs.is_mine = _is_mine - self.assertEquals(403, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - @defer.inlineCallbacks def test_get_filter(self): - filter_id = yield self.filtering.add_user_filter( - user_localpart='apple', - user_filter=self.EXAMPLE_FILTER + filter_id = self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER ) - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/%s" % (self.USER_ID, filter_id) + self.clock.advance(1) + filter_id = filter_id.result + request, channel = make_request( + b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) ) - self.assertEquals(200, code) - self.assertEquals(self.EXAMPLE_FILTER, response) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) - @defer.inlineCallbacks def test_get_filter_non_existant(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/12382148321" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) ) - self.assertEquals(400, code) - self.assertEquals(response['errcode'], Codes.NOT_FOUND) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") + self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode # in errors.py - @defer.inlineCallbacks def test_get_filter_invalid_id(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/foobar" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) ) - self.assertEquals(400, code) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error - @defer.inlineCallbacks def test_get_filter_no_id(self): - (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/" % (self.USER_ID) + request, channel = make_request( + b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) ) - self.assertEquals(400, code) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b6173ab2ee..e004d8fc73 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,158 +1,193 @@ -from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.api.errors import SynapseError -from twisted.internet import defer +import json + from mock import Mock + +from twisted.python import failure +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.util import Clock + from tests import unittest -from tests.utils import mock_getRawHeaders -import json +from tests.server import make_request, setup_test_homeserver, wait_until_result class RegisterRestServletTestCase(unittest.TestCase): - def setUp(self): - # do the dance to hook up request data to self.request_data - self.request_data = "" - self.request = Mock( - content=Mock(read=Mock(side_effect=lambda: self.request_data)), - path='/_matrix/api/v2_alpha/register' - ) - self.request.args = {} - self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = b"/_matrix/client/r0/register" self.appservice = None - self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: self.appservice) + self.auth = Mock( + get_appservice_by_req=Mock(side_effect=lambda x: self.appservice) ) - self.auth_result = (False, None, None, None) + self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) self.auth_handler = Mock( check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None) + get_session_data=Mock(return_value=None), ) self.registration_handler = Mock() self.identity_handler = Mock() self.login_handler = Mock() self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=[]) # do the dance to hook it up to the hs global self.handlers = Mock( registration_handler=self.registration_handler, identity_handler=self.identity_handler, - login_handler=self.login_handler + login_handler=self.login_handler, + ) + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock ) - self.hs = Mock() - self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) + self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] - # init the thing we're testing - self.servlet = RegisterRestServlet(self.hs) + self.resource = JsonResource(self.hs) + register_servlets(self.hs, self.resource) - @defer.inlineCallbacks def test_POST_appservice_registration_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "username": "kermit" - }) - self.appservice = { - "id": "1234" - } - self.registration_handler.appservice_register = Mock( - return_value=user_id - ) - self.auth_handler.get_access_token_for_user_id = Mock( - return_value=token + self.appservice = {"id": "1234"} + self.registration_handler.appservice_register = Mock(return_value=user_id) + self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) + request_data = json.dumps({"username": "kermit"}) + + request, channel = make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) + request.render(self.resource) + wait_until_result(self.clock, channel) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) + self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, } - self.assertDictContainsSubset(det_data, result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) - @defer.inlineCallbacks def test_POST_appservice_registration_invalid(self): - self.request.args = { - "access_token": "i_am_an_app_service" - } - self.request_data = json.dumps({ - "username": "kermit" - }) self.appservice = None # no application service exists - result = yield self.servlet.on_POST(self.request) - self.assertEquals(result, (401, None)) + request_data = json.dumps({"username": "kermit"}) + request, channel = make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): - self.request_data = json.dumps({ - "username": "kermit", - "password": 666 - }) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) + request_data = json.dumps({"username": "kermit", "password": 666}) + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Invalid password" + ) def test_POST_bad_username(self): - self.request_data = json.dumps({ - "username": 777, - "password": "monkey" - }) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) - - @defer.inlineCallbacks + request_data = json.dumps({"username": 777, "password": "monkey"}) + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Invalid username" + ) + def test_POST_user_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" device_id = "frogfone" - self.request_data = json.dumps({ - "username": "kermit", - "password": "monkey", - "device_id": device_id, - }) + request_data = json.dumps( + {"username": "kermit", "password": "monkey", "device_id": device_id} + ) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (True, None, { - "username": "kermit", - "password": "monkey" - }, None) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.get_access_token_for_user_id = Mock( - return_value=token - ) - self.device_handler.check_device_registered = \ - Mock(return_value=device_id) + self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) + self.device_handler.check_device_registered = Mock(return_value=device_id) + + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) - (code, result) = yield self.servlet.on_POST(self.request) - self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, "home_server": self.hs.hostname, "device_id": device_id, } - self.assertDictContainsSubset(det_data, result) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) self.auth_handler.get_login_tuple_for_user_id( - user_id, device_id=device_id, initial_device_display_name=None) + user_id, device_id=device_id, initial_device_display_name=None + ) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False - self.request_data = json.dumps({ - "username": "kermit", - "password": "monkey" - }) + request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (True, None, { - "username": "kermit", - "password": "monkey" - }, None) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.registration_handler.register = Mock(return_value=("@user:id", "t")) - d = self.servlet.on_POST(self.request) - return self.assertFailure(d, SynapseError) + + request, channel = make_request(b"POST", self.url, request_data) + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], + "Registration has been disabled", + ) + + def test_POST_guest_registration(self): + user_id = "a@b" + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + self.registration_handler.register = Mock(return_value=(user_id, None)) + + request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") + request.render(self.resource) + wait_until_result(self.clock, channel) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": "guest_device", + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, json.loads(channel.result["body"])) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + json.loads(channel.result["body"])["error"], "Guest access is disabled" + ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py new file mode 100644 index 0000000000..03ec3993b2 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.types +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha import sync +from synapse.types import UserID +from synapse.util import Clock + +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock as MemoryReactorClock, + make_request, + setup_test_homeserver, + wait_until_result, +) + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.TestCase): + + USER_ID = b"@apple:test" + TO_REGISTER = [sync] + + def setUp(self): + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + + self.hs = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.clock + ) + + self.auth = self.hs.get_auth() + + def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.USER_ID), + "token_id": 1, + "is_guest": False, + } + + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.USER_ID), 1, False, None + ) + + self.auth.get_user_by_access_token = get_user_by_access_token + self.auth.get_user_by_req = get_user_by_req + + self.store = self.hs.get_datastore() + self.filtering = self.hs.get_filtering() + self.resource = JsonResource(self.hs) + + for r in self.TO_REGISTER: + r.register_servlets(self.hs, self.resource) + + def test_sync_argless(self): + request, channel = make_request(b"GET", b"/_matrix/client/r0/sync") + request.render(self.resource) + wait_until_result(self.clock, channel) + + self.assertEqual(channel.result["code"], b"200") + self.assertTrue( + set( + [ + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + ] + ).issubset(set(channel.json_body.keys())) + ) diff --git a/tests/rest/media/__init__.py b/tests/rest/media/__init__.py new file mode 100644 index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rest/media/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py new file mode 100644 index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rest/media/v1/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py new file mode 100644 index 0000000000..bf254a260d --- /dev/null +++ b/tests/rest/media/v1/test_media_storage.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +import tempfile + +from mock import Mock + +from twisted.internet import defer, reactor + +from synapse.rest.media.v1._base import FileInfo +from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend + +from tests import unittest + + +class MediaStorageTests(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + + self.primary_base_path = os.path.join(self.test_dir, "primary") + self.secondary_base_path = os.path.join(self.test_dir, "secondary") + + hs = Mock() + hs.get_reactor = Mock(return_value=reactor) + hs.config.media_store_path = self.primary_base_path + + storage_providers = [FileStorageProviderBackend( + hs, self.secondary_base_path + )] + + self.filepaths = MediaFilePaths(self.primary_base_path) + self.media_storage = MediaStorage( + hs, self.primary_base_path, self.filepaths, storage_providers, + ) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + @defer.inlineCallbacks + def test_ensure_media_is_in_local_cache(self): + media_id = "some_media_id" + test_body = "Test\n" + + # First we create a file that is in a storage provider but not in the + # local primary media store + rel_path = self.filepaths.local_media_filepath_rel(media_id) + secondary_path = os.path.join(self.secondary_base_path, rel_path) + + os.makedirs(os.path.dirname(secondary_path)) + + with open(secondary_path, "w") as f: + f.write(test_body) + + # Now we run ensure_media_is_in_local_cache, which should copy the file + # to the local cache. + file_info = FileInfo(None, media_id) + local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info) + + self.assertTrue(os.path.exists(local_path)) + + # Asserts the file is under the expected local cache directory + self.assertEquals( + os.path.commonprefix([self.primary_base_path, local_path]), + self.primary_base_path, + ) + + with open(local_path) as f: + body = f.read() + + self.assertEqual(test_body, body) diff --git a/tests/server.py b/tests/server.py new file mode 100644 index 0000000000..c611dd6059 --- /dev/null +++ b/tests/server.py @@ -0,0 +1,193 @@ +import json +from io import BytesIO + +from six import text_type + +import attr + +from twisted.internet import threads +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.http.site import SynapseRequest + +from tests.utils import setup_test_homeserver as _sth + + +@attr.s +class FakeChannel(object): + """ + A fake Twisted Web Channel (the part that interfaces with the + wire). + """ + + result = attr.ib(default=attr.Factory(dict)) + + @property + def json_body(self): + if not self.result: + raise Exception("No result yet.") + return json.loads(self.result["body"]) + + def writeHeaders(self, version, code, reason, headers): + self.result["version"] = version + self.result["code"] = code + self.result["reason"] = reason + self.result["headers"] = headers + + def write(self, content): + if "body" not in self.result: + self.result["body"] = b"" + + self.result["body"] += content + + def requestDone(self, _self): + self.result["done"] = True + + def getPeer(self): + return None + + def getHost(self): + return None + + @property + def transport(self): + return self + + +class FakeSite: + """ + A fake Twisted Web Site, with mocks of the extra things that + Synapse adds. + """ + + server_version_string = b"1" + site_tag = "test" + + @property + def access_logger(self): + class FakeLogger: + def info(self, *args, **kwargs): + pass + + return FakeLogger() + + +def make_request(method, path, content=b""): + """ + Make a web request using the given method and path, feed it the + content, and return the Request and the Channel underneath. + """ + + # Decorate it to be the full path + if not path.startswith(b"/_matrix"): + path = b"/_matrix/client/r0/" + path + path = path.replace("//", "/") + + if isinstance(content, text_type): + content = content.encode('utf8') + + site = FakeSite() + channel = FakeChannel() + + req = SynapseRequest(site, channel) + req.process = lambda: b"" + req.content = BytesIO(content) + req.requestReceived(method, path, b"1.1") + + return req, channel + + +def wait_until_result(clock, channel, timeout=100): + """ + Wait until the channel has a result. + """ + clock.run() + x = 0 + + while not channel.result: + x += 1 + + if x > timeout: + raise Exception("Timed out waiting for request to finish.") + + clock.advance(0.1) + + +def render(request, resource, clock): + request.render(resource) + wait_until_result(clock, request._channel) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + return d + + +def setup_test_homeserver(*args, **kwargs): + """ + Set up a synchronous test server, driven by the reactor used by + the homeserver. + """ + d = _sth(*args, **kwargs).result + + # Make the thread pool synchronous. + clock = d.get_clock() + pool = d.get_db_pool() + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) + + pool.runWithConnection = runWithConnection + pool.runInteraction = runInteraction + + class ThreadPool: + """ + Threadless thread pool. + """ + def start(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + clock._reactor.callLater(0, d.callback, True) + return d + + clock.threadpool = ThreadPool() + pool.threadpool = ThreadPool() + return d diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py deleted file mode 100644 index 024ac15069..0000000000 --- a/tests/storage/event_injector.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.api.constants import EventTypes - - -class EventInjector: - def __init__(self, hs): - self.hs = hs - self.store = hs.get_datastore() - self.message_handler = hs.get_handlers().message_handler - self.event_builder_factory = hs.get_event_builder_factory() - - @defer.inlineCallbacks - def create_room(self, room, user): - builder = self.event_builder_factory.new({ - "type": EventTypes.Create, - "sender": user.to_string(), - "room_id": room.to_string(), - "content": {}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 3cfa21c9f8..6d6f00c5c5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,15 +14,15 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock -from synapse.util.async import ObservableDeferred +from twisted.internet import defer +from synapse.util.async import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached +from tests import unittest + class CacheTestCase(unittest.TestCase): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 9e98d0e330..099861b27c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -12,21 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json +import os import tempfile -from synapse.config._base import ConfigError -from tests import unittest + +from mock import Mock + +import yaml + from twisted.internet import defer -from tests.utils import setup_test_homeserver from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.config._base import ConfigError from synapse.storage.appservice import ( - ApplicationServiceStore, ApplicationServiceTransactionStore + ApplicationServiceStore, + ApplicationServiceTransactionStore, ) -import json -import os -import yaml -from mock import Mock +from tests import unittest +from tests.utils import setup_test_homeserver class ApplicationServiceStoreTestCase(unittest.TestCase): @@ -42,7 +46,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) self.as_token = "token1" @@ -58,14 +62,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs) + self.store = ApplicationServiceStore(None, hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! for f in self.as_yaml_files: try: os.remove(f) - except: + except Exception: pass def _add_appservice(self, as_token, id, url, hs_token, sender): @@ -119,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs = yield setup_test_homeserver( config=config, federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) self.db_pool = hs.get_db_pool() @@ -150,7 +154,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs) + self.store = TestTransactionStore(None, hs) def _add_service(self, url, as_token, id): as_yaml = dict(url=url, as_token=as_token, hs_token="something", @@ -420,8 +424,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, hs): - super(TestTransactionStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(TestTransactionStore, self).__init__(db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): @@ -455,10 +459,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): config=config, datastore=Mock(), federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) - ApplicationServiceStore(hs) + ApplicationServiceStore(None, hs) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -473,16 +477,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): config=config, datastore=Mock(), federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs) + ApplicationServiceStore(None, hs) e = cm.exception - self.assertIn(f1, e.message) - self.assertIn(f2, e.message) - self.assertIn("id", e.message) + self.assertIn(f1, str(e)) + self.assertIn(f2, str(e)) + self.assertIn("id", str(e)) @defer.inlineCallbacks def test_duplicate_as_tokens(self): @@ -497,13 +501,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): config=config, datastore=Mock(), federation_sender=Mock(), - replication_layer=Mock(), + federation_client=Mock(), ) with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs) + ApplicationServiceStore(None, hs) e = cm.exception - self.assertIn(f1, e.message) - self.assertIn(f2, e.message) - self.assertIn("as_token", e.message) + self.assertIn(f1, str(e)) + self.assertIn(f2, str(e)) + self.assertIn("as_token", str(e)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 1286b4ce2d..ab1f310572 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,10 +1,10 @@ -from tests import unittest +from mock import Mock + from twisted.internet import defer +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class BackgroundUpdateTestCase(unittest.TestCase): diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 91e971190c..1d1234ee39 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -14,18 +14,18 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer +from collections import OrderedDict from mock import Mock -from collections import OrderedDict +from twisted.internet import defer from synapse.server import HomeServer - from synapse.storage._base import SQLBaseStore from synapse.storage.engines import create_engine +from tests import unittest + class SQLBaseStoreTestCase(unittest.TestCase): """ Test the "simple" SQL generating methods in SQLBaseStore. """ @@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): database_engine=create_engine(config.database_config), ) - self.datastore = SQLBaseStore(hs) + self.datastore = SQLBaseStore(None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 03df697575..bd6fda6cb1 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -15,9 +15,6 @@ from twisted.internet import defer -import synapse.server -import synapse.storage -import synapse.types import tests.unittest import tests.utils @@ -39,7 +36,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.clock.now = 12345678 user_id = "@user:id" yield self.store.insert_client_ip( - synapse.types.UserID.from_string(user_id), + user_id, "access_token", "ip", "user_agent", "device_id", ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f8725acea0..a54cc6bc32 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -16,6 +16,7 @@ from twisted.internet import defer import synapse.api.errors + import tests.unittest import tests.utils diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index b087892e0b..129ebaf343 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.directory import DirectoryStore -from synapse.types import RoomID, RoomAlias +from synapse.types import RoomAlias, RoomID +from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver() - self.store = DirectoryStore(hs) + self.store = DirectoryStore(None, hs) self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py new file mode 100644 index 0000000000..30683e7888 --- /dev/null +++ b/tests/storage/test_event_federation.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.unittest +import tests.utils + + +class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_prev_events_for_room(self): + room_id = '@ROOM:local' + + # add a bunch of events and hashes to act as forward extremities + def insert_event(txn, i): + event_id = '$event_%i:local' % i + + txn.execute(( + "INSERT INTO events (" + " room_id, event_id, type, depth, topological_ordering," + " content, processed, outlier) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + ), (room_id, event_id, i, i, True, False)) + + txn.execute(( + 'INSERT INTO event_forward_extremities (room_id, event_id) ' + 'VALUES (?, ?)' + ), (room_id, event_id)) + + txn.execute(( + 'INSERT INTO event_reference_hashes ' + '(event_id, algorithm, hash) ' + "VALUES (?, 'sha256', ?)" + ), (event_id, 'ffff')) + + for i in range(0, 11): + yield self.store.runInteraction("insert", insert_event, i) + + # this should get the last five and five others + r = yield self.store.get_prev_events_for_room(room_id) + self.assertEqual(10, len(r)) + for i in range(0, 5): + el = r[i] + depth = el[2] + self.assertEqual(10 - i, depth) + + for i in range(5, 5): + el = r[i] + depth = el[2] + self.assertLessEqual(5, depth) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 3135488353..8430fc7ba6 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer import tests.unittest import tests.utils -from mock import Mock USER_ID = "@user:example.com" @@ -55,13 +56,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def _assert_counts(noitf_count, highlight_count): counts = yield self.store.runInteraction( "", self.store._get_unread_counts_by_pos_txn, - room_id, user_id, 0, 0 + room_id, user_id, 0 ) self.assertEquals( counts, {"notify_count": noitf_count, "highlight_count": highlight_count} ) + @defer.inlineCallbacks def _inject_actions(stream, action): event = Mock() event.room_id = room_id @@ -69,11 +71,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - tuples = [(user_id, action)] - - return self.store.runInteraction( + yield self.store.add_push_actions_to_staging( + event.event_id, {user_id: action}, + ) + yield self.store.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, - event, tuples + [(event, None)], [(event, None)], ) def _rotate(stream): @@ -84,7 +87,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def _mark_read(stream, depth): return self.store.runInteraction( "", self.store._remove_old_push_actions_before_txn, - room_id, user_id, depth, stream + room_id, user_id, stream ) yield _assert_counts(0, 0) @@ -125,3 +128,69 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _assert_counts(1, 1) yield _rotate(10) yield _assert_counts(1, 1) + + @defer.inlineCallbacks + def test_find_first_stream_ordering_after_ts(self): + def add_event(so, ts): + return self.store._simple_insert("events", { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }) + + # start with the base case where there are no events in the table + r = yield self.store.find_first_stream_ordering_after_ts(11) + self.assertEqual(r, 0) + + # now with one event + yield add_event(2, 10) + r = yield self.store.find_first_stream_ordering_after_ts(9) + self.assertEqual(r, 2) + r = yield self.store.find_first_stream_ordering_after_ts(10) + self.assertEqual(r, 2) + r = yield self.store.find_first_stream_ordering_after_ts(11) + self.assertEqual(r, 3) + + # add a bunch of dummy events to the events table + for (stream_ordering, ts) in ( + (3, 110), + (4, 120), + (5, 120), + (10, 130), + (20, 140), + ): + yield add_event(stream_ordering, ts) + + r = yield self.store.find_first_stream_ordering_after_ts(110) + self.assertEqual(r, 3, + "First event after 110ms should be 3, was %i" % r) + + # 4 and 5 are both after 120: we want 4 rather than 5 + r = yield self.store.find_first_stream_ordering_after_ts(120) + self.assertEqual(r, 4, + "First event after 120ms should be 4, was %i" % r) + + r = yield self.store.find_first_stream_ordering_after_ts(129) + self.assertEqual(r, 10, + "First event after 129ms should be 10, was %i" % r) + + # check we can get the last event + r = yield self.store.find_first_stream_ordering_after_ts(140) + self.assertEqual(r, 20, + "First event after 14ms should be 20, was %i" % r) + + # off the end + r = yield self.store.find_first_stream_ordering_after_ts(160) + self.assertEqual(r, 21) + + # check we can find an event at ordering zero + yield add_event(0, 5) + r = yield self.store.find_first_stream_ordering_after_ts(1) + self.assertEqual(r, 0) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 0be790d8f8..3a3d002782 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key + from twisted.internet import defer import tests.unittest diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index 63203cea35..3276b39504 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -14,13 +14,13 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.presence import PresenceStore from synapse.types import UserID -from tests.utils import setup_test_homeserver, MockClock +from tests import unittest +from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(clock=MockClock()) - self.store = PresenceStore(hs) + self.store = PresenceStore(None, hs) self.u_apple = UserID.from_string("@apple:test") self.u_banana = UserID.from_string("@banana:test") diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24118bbc86..2c95e5e95a 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.profile import ProfileStore from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver() - self.store = ProfileStore(hs) + self.store = ProfileStore(None, hs) self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 6afaca3a61..475ec900c4 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID +from synapse.types import RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class RedactionTestCase(unittest.TestCase): @@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler + self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = UserID.from_string("@alice:test") self.u_bob = UserID.from_string("@bob:test") @@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase): "content": content, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) @@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase): "content": {"body": body, "msgtype": u"message"}, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) @@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase): "redacts": event_id, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 316ecdb32d..7821ea3fa3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -14,9 +14,9 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer +from tests import unittest from tests.utils import setup_test_homeserver @@ -42,9 +42,15 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) self.assertEquals( - # TODO(paul): Surely this field should be 'user_id', not 'name' - # Additionally surely it shouldn't come in a 1-element list - {"name": self.user_id, "password_hash": self.pwhash, "is_guest": 0}, + { + # TODO(paul): Surely this field should be 'user_id', not 'name' + "name": self.user_id, + "password_hash": self.pwhash, + "is_guest": 0, + "consent_version": None, + "consent_server_notice_sent": None, + "appservice_id": None, + }, (yield self.store.get_user_by_id(self.user_id)) ) @@ -86,7 +92,8 @@ class RegistrationStoreTestCase(unittest.TestCase): # now delete some yield self.store.user_delete_access_tokens( - self.user_id, device_id=self.device_id, delete_refresh_tokens=True) + self.user_id, device_id=self.device_id, + ) # check they were deleted user = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -97,8 +104,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertEqual(self.user_id, user["name"]) # now delete the rest - yield self.store.user_delete_access_tokens( - self.user_id, delete_refresh_tokens=True) + yield self.store.user_delete_access_tokens(self.user_id) user = yield self.store.get_user_by_access_token(self.tokens[0]) self.assertIsNone(user, diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ef8a4d234f..ae8ae94b6d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.types import UserID, RoomID, RoomAlias +from synapse.types import RoomAlias, RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 1be7d932f6..c5fd54f67e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID +from synapse.types import RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class RoomMemberStoreTestCase(unittest.TestCase): @@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): # storage logic self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() - self.handlers = hs.get_handlers() - self.message_handler = self.handlers.message_handler + self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = UserID.from_string("@alice:test") self.u_bob = UserID.from_string("@bob:test") @@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): "content": {"membership": membership}, }) - event, context = yield self.message_handler._create_new_client_event( + event, context = yield self.event_creation_handler.create_new_client_event( builder ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py new file mode 100644 index 0000000000..23fad12bca --- /dev/null +++ b/tests/storage/test_user_directory.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.storage import UserDirectoryStore +from synapse.storage.roommember import ProfileInfo + +from tests import unittest +from tests.utils import setup_test_homeserver + +ALICE = "@alice:a" +BOB = "@bob:b" +BOBBY = "@bobby:a" + + +class UserDirectoryStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.store = UserDirectoryStore(None, self.hs) + + # alice and bob are both in !room_id. bobby is not but shares + # a homeserver with alice. + yield self.store.add_profiles_to_user_dir( + "!room:id", + { + ALICE: ProfileInfo(None, "alice"), + BOB: ProfileInfo(None, "bob"), + BOBBY: ProfileInfo(None, "bobby") + }, + ) + yield self.store.add_users_to_public_room( + "!room:id", + [ALICE, BOB], + ) + yield self.store.add_users_who_share_room( + "!room:id", + False, + ( + (ALICE, BOB), + (BOB, ALICE), + ), + ) + + @defer.inlineCallbacks + def test_search_user_dir(self): + # normally when alice searches the directory she should just find + # bob because bobby doesn't share a room with her. + r = yield self.store.search_user_dir(ALICE, "bob", 10) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual(r["results"][0], { + "user_id": BOB, + "display_name": "bob", + "avatar_url": None, + }) + + @defer.inlineCallbacks + def test_search_user_dir_all_users(self): + self.hs.config.user_directory_search_all_users = True + try: + r = yield self.store.search_user_dir(ALICE, "bob", 10) + self.assertFalse(r["limited"]) + self.assertEqual(2, len(r["results"])) + self.assertDictEqual(r["results"][0], { + "user_id": BOB, + "display_name": "bob", + "avatar_url": None, + }) + self.assertDictEqual(r["results"][1], { + "user_id": BOBBY, + "display_name": "bobby", + "avatar_url": None, + }) + finally: + self.hs.config.user_directory_search_all_users = False diff --git a/tests/test_distributor.py b/tests/test_distributor.py index acebcf4a86..71d11cda77 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest -from twisted.internet import defer - from mock import Mock, patch from synapse.util.distributor import Distributor -from synapse.util.async import run_on_reactor + +from . import unittest class DistributorTestCase(unittest.TestCase): @@ -27,42 +26,19 @@ class DistributorTestCase(unittest.TestCase): def setUp(self): self.dist = Distributor() - @defer.inlineCallbacks def test_signal_dispatch(self): self.dist.declare("alert") observer = Mock() self.dist.observe("alert", observer) - d = self.dist.fire("alert", 1, 2, 3) - yield d - self.assertTrue(d.called) + self.dist.fire("alert", 1, 2, 3) observer.assert_called_with(1, 2, 3) - @defer.inlineCallbacks - def test_signal_dispatch_deferred(self): - self.dist.declare("whine") - - d_inner = defer.Deferred() - - def observer(): - return d_inner - - self.dist.observe("whine", observer) - - d_outer = self.dist.fire("whine") - - self.assertFalse(d_outer.called) - - d_inner.callback(None) - yield d_outer - self.assertTrue(d_outer.called) - - @defer.inlineCallbacks def test_signal_catch(self): self.dist.declare("alarm") - observers = [Mock() for i in 1, 2] + observers = [Mock() for i in (1, 2)] for o in observers: self.dist.observe("alarm", o) @@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase): with patch( "synapse.util.distributor.logger", spec=["warning"] ) as mock_logger: - d = self.dist.fire("alarm", "Go") - yield d - self.assertTrue(d.called) + self.dist.fire("alarm", "Go") observers[0].assert_called_once_with("Go") observers[1].assert_called_once_with("Go") @@ -83,35 +57,12 @@ class DistributorTestCase(unittest.TestCase): mock_logger.warning.call_args[0][0], str ) - @defer.inlineCallbacks - def test_signal_catch_no_suppress(self): - # Gut-wrenching - self.dist.suppress_failures = False - - self.dist.declare("whail") - - class MyException(Exception): - pass - - @defer.inlineCallbacks - def observer(): - yield run_on_reactor() - raise MyException("Oopsie") - - self.dist.observe("whail", observer) - - d = self.dist.fire("whail") - - yield self.assertFailure(d, MyException) - self.dist.suppress_failures = True - - @defer.inlineCallbacks def test_signal_prereg(self): observer = Mock() self.dist.observe("flare", observer) self.dist.declare("flare") - yield self.dist.fire("flare", 4, 5) + self.dist.fire("flare", 4, 5) observer.assert_called_with(4, 5) diff --git a/tests/test_dns.py b/tests/test_dns.py index c394c57ee7..b647d92697 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -13,26 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest +from mock import Mock + from twisted.internet import defer from twisted.names import dns, error -from mock import Mock - from synapse.http.endpoint import resolve_service from tests.utils import MockClock +from . import unittest + +@unittest.DEBUG class DnsTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve(self): dns_client_mock = Mock() - service_name = "test_service.examle.com" + service_name = "test_service.example.com" host_name = "example.com" - ip_address = "127.0.0.1" answer_srv = dns.RRHeader( type=dns.SRV, @@ -41,16 +42,10 @@ class DnsTestCase(unittest.TestCase): ) ) - answer_a = dns.RRHeader( - type=dns.A, - payload=dns.Record_A( - address=ip_address, - ) + dns_client_mock.lookupService.return_value = defer.succeed( + ([answer_srv], None, None), ) - dns_client_mock.lookupService.return_value = ([answer_srv], None, None) - dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) - cache = {} servers = yield resolve_service( @@ -58,18 +53,17 @@ class DnsTestCase(unittest.TestCase): ) dns_client_mock.lookupService.assert_called_once_with(service_name) - dns_client_mock.lookupAddress.assert_called_once_with(host_name) self.assertEquals(len(servers), 1) self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, ip_address) + self.assertEquals(servers[0].host, host_name) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" entry = Mock(spec_set=["expires"]) entry.expires = 0 @@ -94,7 +88,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock = Mock(spec_set=['lookupService']) dns_client_mock.lookupService = Mock(spec_set=[]) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" entry = Mock(spec_set=["expires"]) entry.expires = 999999999 @@ -118,7 +112,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" cache = {} @@ -133,7 +127,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" cache = {} diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py new file mode 100644 index 0000000000..06112430e5 --- /dev/null +++ b/tests/test_event_auth.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from synapse import event_auth +from synapse.api.errors import AuthError +from synapse.events import FrozenEvent + + +class EventAuthTestCase(unittest.TestCase): + def test_random_users_cannot_send_state_before_first_pl(self): + """ + Check that, before the first PL lands, the creator is the only user + that can send a state event. + """ + creator = "@creator:example.com" + joiner = "@joiner:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.member", joiner): _join_event(joiner), + } + + # creator should be able to send state + event_auth.check( + _random_state_event(creator), auth_events, + do_sig_check=False, + ) + + # joiner should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(joiner), + auth_events, + do_sig_check=False, + ), + + def test_state_default_level(self): + """ + Check that users above the state_default level can send state and + those below cannot + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + king = "@joiner2:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event(creator, { + "state_default": "30", + "users": { + pleb: "29", + king: "30", + }, + }), + ("m.room.member", pleb): _join_event(pleb), + ("m.room.member", king): _join_event(king), + } + + # pleb should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(pleb), + auth_events, + do_sig_check=False, + ), + + # king should be able to send state + event_auth.check( + _random_state_event(king), auth_events, + do_sig_check=False, + ) + + +# helpers for making events + +TEST_ROOM_ID = "!test:room" + + +def _create_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.create", + "sender": user_id, + "content": { + "creator": user_id, + }, + }) + + +def _join_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "content": { + "membership": "join", + }, + }) + + +def _power_levels_event(sender, content): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.power_levels", + "sender": sender, + "state_key": "", + "content": content, + }) + + +def _random_state_event(sender): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "test.state", + "sender": sender, + "state_key": "", + "content": { + "membership": "join", + }, + }) + + +event_count = 0 + + +def _get_event_id(): + global event_count + c = event_count + event_count += 1 + return "!%i:example.com" % (c, ) diff --git a/tests/test_federation.py b/tests/test_federation.py new file mode 100644 index 0000000000..f40ff29b52 --- /dev/null +++ b/tests/test_federation.py @@ -0,0 +1,242 @@ + +from mock import Mock + +from twisted.internet.defer import maybeDeferred, succeed + +from synapse.events import FrozenEvent +from synapse.types import Requester, UserID +from synapse.util import Clock + +from tests import unittest +from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver + + +class MessageAcceptTests(unittest.TestCase): + def setUp(self): + + self.http_client = Mock() + self.reactor = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + ) + + user_id = UserID("us", "test") + our_user = Requester(user_id, None, False, None, None) + room_creator = self.homeserver.get_room_creation_handler() + room = room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) + self.reactor.advance(0.1) + self.room_id = self.successResultOf(room)["room_id"] + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + join_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$join:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "prev_state": [(most_recent, {})], + "prev_events": [(most_recent, {})], + } + ) + + self.handler = self.homeserver.get_handlers().federation_handler + self.handler.do_auth = lambda *a, **b: succeed(True) + self.client = self.homeserver.get_federation_client() + self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( + pdus + ) + + # Send the join, it should return None (which is not an error) + d = self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + # Make sure we actually joined the room + self.assertEqual( + self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0], + "$join:test.serv", + ) + + def test_cant_hide_direct_ancestors(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + # If it asks us for new missing events, give them NOTHING + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return {"events": []} + + self.http_client.post_json = post_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + # Now lie about an event + lying_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("two:test.serv", {}), (most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) + + # Step the reactor, so the database fetches come back + self.reactor.advance(1) + + # on_receive_pdu should throw an error + failure = self.failureResultOf(d) + self.assertEqual( + failure.value.args[0], + ( + "ERROR 403: Your server isn't divulging details about prev_events " + "referenced in this event." + ), + ) + + # Make sure the invalid event isn't there + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + + def test_cant_hide_past_history(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return { + "events": [ + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "three:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("four:test.serv", {})], + } + ] + } + + self.http_client.post_json = post_json + + def get_json(destination, path, args, headers=None): + if path.startswith("/_matrix/federation/v1/state_ids/"): + d = self.successResultOf( + self.homeserver.datastore.get_state_ids_for_event("one:test.serv") + ) + + return succeed( + { + "pdu_ids": [ + y + for x, y in d.items() + if x == ("m.room.member", "@us:test") + ], + "auth_chain_ids": list(d.values()), + } + ) + + self.http_client.get_json = get_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + # Make a good event + good_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [(most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", good_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + bad_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "two:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("one:test.serv", {}), ("three:test.serv", {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", bad_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") + + state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) + self.reactor.advance(1) + self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys()) diff --git a/tests/test_preview.py b/tests/test_preview.py index 5bd36c74aa..446843367e 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest - from synapse.rest.media.v1.preview_url_resource import ( - summarize_paragraphs, decode_and_calc_og + decode_and_calc_og, + summarize_paragraphs, ) +from . import unittest + class PreviewTestCase(unittest.TestCase): diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000000..7e063c0290 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,131 @@ +import json +import re + +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import JsonResource +from synapse.util import Clock + +from tests import unittest +from tests.server import make_request, setup_test_homeserver + + +class JsonResourceTests(unittest.TestCase): + def setUp(self): + self.reactor = MemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.reactor + ) + + def test_handler_for_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + got_kwargs = {} + + def _callback(request, **kwargs): + got_kwargs.update(kwargs) + return (200, kwargs) + + res = JsonResource(self.homeserver) + res.register_paths( + "GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback + ) + + request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") + request.render(res) + + self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) + self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) + + def test_callback_direct_exception(self): + """ + If the web callback raises an uncaught exception, it will be translated + into a 500. + """ + + def _callback(request, **kwargs): + raise Exception("boo") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + + request, channel = make_request(b"GET", b"/_matrix/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'500') + + def test_callback_indirect_exception(self): + """ + If the web callback raises an uncaught exception in a Deferred, it will + be translated into a 500. + """ + + def _throw(*args): + raise Exception("boo") + + def _callback(request, **kwargs): + d = Deferred() + d.addCallback(_throw) + self.reactor.callLater(1, d.callback, True) + return d + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + + request, channel = make_request(b"GET", b"/_matrix/foo") + request.render(res) + + # No error has been raised yet + self.assertTrue("code" not in channel.result) + + # Advance time, now there's an error + self.reactor.advance(1) + self.assertEqual(channel.result["code"], b'500') + + def test_callback_synapseerror(self): + """ + If the web callback raises a SynapseError, it returns the appropriate + status code and message set in it. + """ + + def _callback(request, **kwargs): + raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + + request, channel = make_request(b"GET", b"/_matrix/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'403') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Forbidden!!one!") + self.assertEqual(reply_body["errcode"], "M_FORBIDDEN") + + def test_no_handler(self): + """ + If there is no handler to process the request, Synapse will return 400. + """ + + def _callback(request, **kwargs): + """ + Not ever actually called! + """ + self.fail("shouldn't ever get here") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) + + request, channel = make_request(b"GET", b"/_matrix/foobar") + request.render(res) + + self.assertEqual(channel.result["code"], b'400') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Unrecognized request") + self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED") diff --git a/tests/test_state.py b/tests/test_state.py index feb84f3d48..429a18cbf7 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,18 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer -from synapse.events import FrozenEvent from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership -from synapse.state import StateHandler - -from .utils import MockClock +from synapse.events import FrozenEvent +from synapse.state import StateHandler, StateResolutionHandler -from mock import Mock +from tests import unittest +from .utils import MockClock _next_event_id = 1000 @@ -80,14 +80,14 @@ class StateGroupStore(object): return defer.succeed(groups) - def store_state_groups(self, event, context): - if context.current_state_ids is None: - return + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + state_group = self._next_group + self._next_group += 1 - state_events = dict(context.current_state_ids) + self._group_to_state[state_group] = dict(current_state_ids) - self._group_to_state[context.state_group] = state_events - self._event_to_state_group[event.event_id] = context.state_group + return state_group def get_events(self, event_ids, **kwargs): return { @@ -95,10 +95,19 @@ class StateGroupStore(object): if e_id in self._event_id_to_event } + def get_state_group_delta(self, name): + return (None, None) + def register_events(self, events): for e in events: self._event_id_to_event[e.event_id] = e + def register_event_context(self, event, context): + self._event_to_state_group[event.event_id] = context.state_group + + def register_event_id_state_group(self, event_id, state_group): + self._event_to_state_group[event_id] = state_group + class DictObj(dict): def __init__(self, **kwargs): @@ -137,25 +146,16 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): - self.store = Mock( - spec_set=[ - "get_state_groups_ids", - "add_event_hashes", - "get_events", - "get_next_state_group", - "get_state_group_delta", - ] - ) + self.store = StateGroupStore() hs = Mock(spec_set=[ "get_datastore", "get_auth", "get_state_handler", "get_clock", + "get_state_resolution_handler", ]) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) - - self.store.get_next_state_group.side_effect = Mock - self.store.get_state_group_delta.return_value = (None, None) + hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) self.state = StateHandler(hs) self.event_id = 0 @@ -195,17 +195,17 @@ class StateTestCase(unittest.TestCase): } ) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].prev_state_ids)) + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertEqual(2, len(prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -247,21 +247,20 @@ class StateTestCase(unittest.TestCase): } ) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids - self.store.get_events = store.get_events - store.register_events(graph.walk()) + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].prev_state_ids.values()} + {e_id for e_id in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -313,21 +312,20 @@ class StateTestCase(unittest.TestCase): } ) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids - self.store.get_events = store.get_events - store.register_events(graph.walk()) + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) + self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].prev_state_ids.values()} + {e for e in prev_state_ids.values()} ) @defer.inlineCallbacks @@ -396,21 +394,20 @@ class StateTestCase(unittest.TestCase): self._add_depths(nodes, edges) graph = Graph(nodes, edges) - store = StateGroupStore() - self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids - self.store.get_events = store.get_events - store.register_events(graph.walk()) + self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) - store.store_state_groups(event, context) + self.store.register_event_context(event, context) context_store[event.event_id] = context + prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].prev_state_ids.values()} + {e for e in prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -439,8 +436,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(current_state_ids.values()) ) self.assertIsNotNone(context.state_group) @@ -459,13 +458,19 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertEqual( - set(e.event_id for e in old_state), set(context.prev_state_ids.values()) + set(e.event_id for e in old_state), set(prev_state_ids.values()) ) @defer.inlineCallbacks def test_trivial_annotate_message(self): - event = create_event(type="test_message", name="event") + prev_event_id = "prev_event_id" + event = create_event( + type="test_message", name="event2", + prev_events=[(prev_event_id, {})], + ) old_state = [ create_event(type="test1", state_key="1"), @@ -473,24 +478,30 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = "group_name_1" - - self.store.get_state_groups_ids.return_value = { - group_name: {(e.type, e.state_key): e.event_id for e in old_state}, - } + group_name = self.store.store_state_group( + prev_event_id, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) + self.store.register_event_id_state_group(prev_event_id, group_name) context = yield self.state.compute_event_context(event) + current_state_ids = yield context.get_current_state_ids(self.store) + self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @defer.inlineCallbacks def test_trivial_annotate_state(self): - event = create_event(type="state", state_key="", name="event") + prev_event_id = "prev_event_id" + event = create_event( + type="state", state_key="", name="event2", + prev_events=[(prev_event_id, {})], + ) old_state = [ create_event(type="test1", state_key="1"), @@ -498,24 +509,31 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = "group_name_1" - - self.store.get_state_groups_ids.return_value = { - group_name: {(e.type, e.state_key): e.event_id for e in old_state}, - } + group_name = self.store.store_state_group( + prev_event_id, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) + self.store.register_event_id_state_group(prev_event_id, group_name) context = yield self.state.compute_event_context(event) + prev_state_ids = yield context.get_prev_state_ids(self.store) + self.assertEqual( set([e.event_id for e in old_state]), - set(context.prev_state_ids.values()) + set(prev_state_ids.values()) ) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): - event = create_event(type="test_message", name="event") + prev_event_id1 = "event_id1" + prev_event_id2 = "event_id2" + event = create_event( + type="test_message", name="event3", + prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], + ) creation = create_event( type=EventTypes.Create, state_key="" @@ -535,20 +553,27 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] - store = StateGroupStore() - store.register_events(old_state_1) - store.register_events(old_state_2) - self.store.get_events = store.get_events + self.store.register_events(old_state_1) + self.store.register_events(old_state_2) - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) + + current_state_ids = yield context.get_current_state_ids(self.store) - self.assertEqual(len(context.current_state_ids), 6) + self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): - event = create_event(type="test4", state_key="", name="event") + prev_event_id1 = "event_id1" + prev_event_id2 = "event_id2" + event = create_event( + type="test4", state_key="", name="event", + prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], + ) creation = create_event( type=EventTypes.Create, state_key="" @@ -573,15 +598,24 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_2) self.store.get_events = store.get_events - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) + + current_state_ids = yield context.get_current_state_ids(self.store) - self.assertEqual(len(context.current_state_ids), 6) + self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): - event = create_event(type="test4", name="event") + prev_event_id1 = "event_id1" + prev_event_id2 = "event_id2" + event = create_event( + type="test4", name="event", + prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], + ) member_event = create_event( type=EventTypes.Member, @@ -591,6 +625,14 @@ class StateTestCase(unittest.TestCase): } ) + power_levels = create_event( + type=EventTypes.PowerLevels, state_key="", + content={"users": { + "@foo:bar": "100", + "@user_id:example.com": "100", + }} + ) + creation = create_event( type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"} @@ -598,12 +640,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] @@ -613,10 +657,14 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_2) self.store.get_events = store.get_events - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) + + current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual( - old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + old_state_2[3].event_id, current_state_ids[("test1", "1")] ) # Reverse the depth to make sure we are actually using the depths @@ -624,12 +672,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] @@ -637,19 +687,28 @@ class StateTestCase(unittest.TestCase): store.register_events(old_state_1) store.register_events(old_state_2) - context = yield self._get_context(event, old_state_1, old_state_2) + context = yield self._get_context( + event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, + ) + + current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual( - old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + old_state_1[3].event_id, current_state_ids[("test1", "1")] ) - def _get_context(self, event, old_state_1, old_state_2): - group_name_1 = "group_name_1" - group_name_2 = "group_name_2" + def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, + old_state_2): + sg1 = self.store.store_state_group( + prev_event_id_1, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state_1}, + ) + self.store.register_event_id_state_group(prev_event_id_1, sg1) - self.store.get_state_groups_ids.return_value = { - group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, - group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, - } + sg2 = self.store.store_state_group( + prev_event_id_2, event.room_id, None, None, + {(e.type, e.state_key): e.event_id for e in old_state_2}, + ) + self.store.register_event_id_state_group(prev_event_id_2, sg2) return self.state.compute_event_context(event) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index d28bb726bb..bc97c12245 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -14,7 +14,6 @@ # limitations under the License. from tests import unittest - from tests.utils import MockClock diff --git a/tests/test_types.py b/tests/test_types.py index 24d61dbe54..729bd676c1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest - from synapse.api.errors import SynapseError from synapse.server import HomeServer -from synapse.types import UserID, RoomAlias +from synapse.types import GroupID, RoomAlias, UserID + +from tests import unittest mock_homeserver = HomeServer(hostname="my.domain") @@ -60,3 +60,25 @@ class RoomAliasTestCase(unittest.TestCase): room = RoomAlias("channel", "my.domain") self.assertEquals(room.to_string(), "#channel:my.domain") + + +class GroupIDTestCase(unittest.TestCase): + def test_parse(self): + group_id = GroupID.from_string("+group/=_-.123:my.domain") + self.assertEqual("group/=_-.123", group_id.localpart) + self.assertEqual("my.domain", group_id.domain) + + def test_validate(self): + bad_ids = [ + "$badsigil:domain", + "+:empty", + ] + [ + "+group" + c + ":domain" for c in "A%?æ£" + ] + for id_string in bad_ids: + try: + GroupID.from_string(id_string) + self.fail("Parsing '%s' should raise exception" % id_string) + except SynapseError as exc: + self.assertEqual(400, exc.code) + self.assertEqual("M_UNKNOWN", exc.errcode) diff --git a/tests/test_visibility.py b/tests/test_visibility.py new file mode 100644 index 0000000000..0dc1a924d3 --- /dev/null +++ b/tests/test_visibility.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer +from twisted.internet.defer import succeed + +from synapse.events import FrozenEvent +from synapse.visibility import filter_events_for_server + +import tests.unittest +from tests.utils import setup_test_homeserver + +logger = logging.getLogger(__name__) + +TEST_ROOM_ID = "!TEST:ROOM" + + +class FilterEventsForServerTestCase(tests.unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver() + self.event_creation_handler = self.hs.get_event_creation_handler() + self.event_builder_factory = self.hs.get_event_builder_factory() + self.store = self.hs.get_datastore() + + @defer.inlineCallbacks + def test_filtering(self): + # + # The events to be filtered consist of 10 membership events (it doesn't + # really matter if they are joins or leaves, so let's make them joins). + # One of those membership events is going to be for a user on the + # server we are filtering for (so we can check the filtering is doing + # the right thing). + # + + # before we do that, we persist some other events to act as state. + self.inject_visibility("@admin:hs", "joined") + for i in range(0, 10): + yield self.inject_room_member("@resident%i:hs" % i) + + events_to_filter = [] + + for i in range(0, 10): + user = "@user%i:%s" % ( + i, "test_server" if i == 5 else "other_server" + ) + evt = yield self.inject_room_member(user, extra_content={"a": "b"}) + events_to_filter.append(evt) + + filtered = yield filter_events_for_server( + self.store, "test_server", events_to_filter, + ) + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("a", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["a"], "b") + + @tests.unittest.DEBUG + @defer.inlineCallbacks + def test_erased_user(self): + # 4 message events, from erased and unerased users, with a membership + # change in the middle of them. + events_to_filter = [] + + evt = yield self.inject_message("@unerased:local_hs") + events_to_filter.append(evt) + + evt = yield self.inject_message("@erased:local_hs") + events_to_filter.append(evt) + + evt = yield self.inject_room_member("@joiner:remote_hs") + events_to_filter.append(evt) + + evt = yield self.inject_message("@unerased:local_hs") + events_to_filter.append(evt) + + evt = yield self.inject_message("@erased:local_hs") + events_to_filter.append(evt) + + # the erasey user gets erased + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + + # ... and the filtering happens. + filtered = yield filter_events_for_server( + self.store, "test_server", events_to_filter, + ) + + for i in range(0, len(events_to_filter)): + self.assertEqual( + events_to_filter[i].event_id, filtered[i].event_id, + "Unexpected event at result position %i" % (i, ) + ) + + for i in (0, 3): + self.assertEqual( + events_to_filter[i].content["body"], filtered[i].content["body"], + "Unexpected event content at result position %i" % (i,) + ) + + for i in (1, 4): + self.assertNotIn("body", filtered[i].content) + + @defer.inlineCallbacks + def inject_visibility(self, user_id, visibility): + content = {"history_visibility": visibility} + builder = self.event_builder_factory.new({ + "type": "m.room.history_visibility", + "sender": user_id, + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_room_member(self, user_id, membership="join", extra_content={}): + content = {"membership": membership} + content.update(extra_content) + builder = self.event_builder_factory.new({ + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_message(self, user_id, content=None): + if content is None: + content = {"body": "testytest"} + builder = self.event_builder_factory.new({ + "type": "m.room.message", + "sender": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }) + + event, context = yield self.event_creation_handler.create_new_client_event( + builder + ) + + yield self.hs.get_datastore().persist_event(event, context) + defer.returnValue(event) + + @defer.inlineCallbacks + def test_large_room(self): + # see what happens when we have a large room with hundreds of thousands + # of membership events + + # As above, the events to be filtered consist of 10 membership events, + # where one of them is for a user on the server we are filtering for. + + import cProfile + import pstats + import time + + # we stub out the store, because building up all that state the normal + # way is very slow. + test_store = _TestStore() + + # our initial state is 100000 membership events and one + # history_visibility event. + room_state = [] + + history_visibility_evt = FrozenEvent({ + "event_id": "$history_vis", + "type": "m.room.history_visibility", + "sender": "@resident_user_0:test.com", + "state_key": "", + "room_id": TEST_ROOM_ID, + "content": {"history_visibility": "joined"}, + }) + room_state.append(history_visibility_evt) + test_store.add_event(history_visibility_evt) + + for i in range(0, 100000): + user = "@resident_user_%i:test.com" % (i, ) + evt = FrozenEvent({ + "event_id": "$res_event_%i" % (i, ), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": { + "membership": "join", + "extra": "zzz," + }, + }) + room_state.append(evt) + test_store.add_event(evt) + + events_to_filter = [] + for i in range(0, 10): + user = "@user%i:%s" % ( + i, "test_server" if i == 5 else "other_server" + ) + evt = FrozenEvent({ + "event_id": "$evt%i" % (i, ), + "type": "m.room.member", + "state_key": user, + "sender": user, + "room_id": TEST_ROOM_ID, + "content": { + "membership": "join", + "extra": "zzz", + }, + }) + events_to_filter.append(evt) + room_state.append(evt) + + test_store.add_event(evt) + test_store.set_state_ids_for_event(evt, { + (e.type, e.state_key): e.event_id for e in room_state + }) + + pr = cProfile.Profile() + pr.enable() + + logger.info("Starting filtering") + start = time.time() + filtered = yield filter_events_for_server( + test_store, "test_server", events_to_filter, + ) + logger.info("Filtering took %f seconds", time.time() - start) + + pr.disable() + with open("filter_events_for_server.profile", "w+") as f: + ps = pstats.Stats(pr, stream=f).sort_stats('cumulative') + ps.print_stats() + + # the result should be 5 redacted events, and 5 unredacted events. + for i in range(0, 5): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertNotIn("extra", filtered[i].content) + + for i in range(5, 10): + self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) + self.assertEqual(filtered[i].content["extra"], "zzz") + + test_large_room.skip = "Disabled by default because it's slow" + + +class _TestStore(object): + """Implements a few methods of the DataStore, so that we can test + filter_events_for_server + + """ + def __init__(self): + # data for get_events: a map from event_id to event + self.events = {} + + # data for get_state_ids_for_events mock: a map from event_id to + # a map from (type_state_key) -> event_id for the state at that + # event + self.state_ids_for_events = {} + + def add_event(self, event): + self.events[event.event_id] = event + + def set_state_ids_for_event(self, event, state): + self.state_ids_for_events[event.event_id] = state + + def get_state_ids_for_events(self, events, types): + res = {} + include_memberships = False + for (type, state_key) in types: + if type == "m.room.history_visibility": + continue + if type != "m.room.member" or state_key is not None: + raise RuntimeError( + "Unimplemented: get_state_ids with type (%s, %s)" % + (type, state_key), + ) + include_memberships = True + + if include_memberships: + for event_id in events: + res[event_id] = self.state_ids_for_events[event_id] + + else: + k = ("m.room.history_visibility", "") + for event_id in events: + hve = self.state_ids_for_events[event_id][k] + res[event_id] = {k: hve} + + return succeed(res) + + def get_events(self, events): + return succeed({ + event_id: self.events[event_id] for event_id in events + }) + + def are_users_erased(self, users): + return succeed({u: False for u in users}) diff --git a/tests/unittest.py b/tests/unittest.py index 38715972dd..b15b06726b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -13,22 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import twisted +import twisted.logger from twisted.trial import unittest -import logging +from synapse.util.logcontext import LoggingContextFilter + +# Set up putting Synapse's logs into Trial's. +rootLogger = logging.getLogger() + +log_format = ( + "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" +) + -# logging doesn't have a "don't log anything at all EVARRRR setting, -# but since the highest value is 50, 1000000 should do ;) -NEVER = 1000000 +class ToTwistedHandler(logging.Handler): + tx_log = twisted.logger.Logger() -handler = logging.StreamHandler() -handler.setFormatter(logging.Formatter( - "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]" -)) -logging.getLogger().addHandler(handler) -logging.getLogger().setLevel(NEVER) -logging.getLogger("synapse.storage.SQL").setLevel(NEVER) -logging.getLogger("synapse.storage.txn").setLevel(NEVER) + def emit(self, record): + log_entry = self.format(record) + log_level = record.levelname.lower().replace('warning', 'warn') + self.tx_log.emit( + twisted.logger.LogLevel.levelWithName(log_level), + log_entry.replace("{", r"(").replace("}", r")"), + ) + + +handler = ToTwistedHandler() +formatter = logging.Formatter(log_format) +handler.setFormatter(formatter) +handler.addFilter(LoggingContextFilter(request="")) +rootLogger.addHandler(handler) def around(target): @@ -61,10 +78,14 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", NEVER)) + level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) @around(self) def setUp(orig): + # enable debugging of delayed calls - this means that we get a + # traceback when a unit test exits leaving things on the reactor. + twisted.internet.base.DelayedCall.debug = True + old_level = logging.getLogger().level if old_level != level: @@ -88,6 +109,17 @@ class TestCase(unittest.TestCase): except AssertionError as e: raise (type(e))(e.message + " for '.%s'" % key) + def assert_dict(self, required, actual): + """Does a partial assert of a dict. + + Args: + required (dict): The keys and value which MUST be in 'actual'. + actual (dict): The test result. Extra keys will not be checked. + """ + for key in required: + self.assertEquals(required[key], actual[key], + msg="%s mismatch. %s" % (key, actual)) + def DEBUG(target): """A decorator to set the .loglevel attribute to logging.DEBUG. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3f14ab503f..8176a7dabd 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +14,71 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from functools import partial import mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError -from synapse.util import async from synapse.util import logcontext -from twisted.internet import defer from synapse.util.caches import descriptors + from tests import unittest logger = logging.getLogger(__name__) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return logcontext.make_deferred_yieldable(d) + + +class CacheTestCase(unittest.TestCase): + def test_invalidate_all(self): + cache = descriptors.Cache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return the deferreds + self.assertIs(cache.get("key1"), d1) + self.assertIs(cache.get("key2"), d2) + + # let one of the lookups complete + d2.callback("result2") + self.assertEqual(cache.get("key2"), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue( + callback_record[0], "Invalidation callback for key1 not called", + ) + self.assertTrue( + callback_record[1], "Invalidation callback for key2 not called", + ) + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) + + class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): @@ -149,7 +203,8 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - yield async.run_on_reactor() + # we want this to behave like an asynchronous function + yield run_on_reactor() raise SynapseError(400, "blah") return inner_fn() @@ -159,7 +214,12 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" try: - yield obj.fn(1) + d = obj.fn(1) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + yield d self.fail("No exception thrown") except SynapseError: pass diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index bc92f85fa6..26f2fa5800 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -14,10 +14,10 @@ # limitations under the License. -from tests import unittest - from synapse.util.caches.dictionary_cache import DictionaryCache +from tests import unittest + class DictCacheTestCase(unittest.TestCase): @@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase): seq = self.cache.sequence test_value = {"test": "test_simple_cache_hit_full"} - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key) self.assertEqual(test_value, c.value) @@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase): test_value = { "test": "test_simple_cache_hit_partial" } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test"]) self.assertEqual(test_value, c.value) @@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase): test_value = { "test": "test_simple_cache_miss_partial" } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({}, c.value) @@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase): "test2": "test_simple_cache_hit_miss_partial2", "test3": "test_simple_cache_hit_miss_partial3", } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) @@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase): test_value_1 = { "test": "test_simple_cache_hit_miss_partial", } - self.cache.update(seq, key, test_value_1, full=False) + self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) seq = self.cache.sequence test_value_2 = { "test2": "test_simple_cache_hit_miss_partial2", } - self.cache.update(seq, key, test_value_2, full=False) + self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) c = self.cache.get(key) self.assertEqual( diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 31d24adb8b..d12b5e838b 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -14,12 +14,12 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.expiringcache import ExpiringCache from tests.utils import MockClock +from .. import unittest + class ExpiringCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py new file mode 100644 index 0000000000..7ce5f8c258 --- /dev/null +++ b/tests/util/test_file_consumer.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import threading + +from mock import NonCallableMock +from six import StringIO + +from twisted.internet import defer, reactor + +from synapse.util.file_consumer import BackgroundFileConsumer + +from tests import unittest + + +class FileConsumerTests(unittest.TestCase): + + @defer.inlineCallbacks + def test_pull_consumer(self): + string_file = StringIO() + consumer = BackgroundFileConsumer(string_file, reactor=reactor) + + try: + producer = DummyPullProducer() + + yield producer.register_with_consumer(consumer) + + yield producer.write_and_wait("Foo") + + self.assertEqual(string_file.getvalue(), "Foo") + + yield producer.write_and_wait("Bar") + + self.assertEqual(string_file.getvalue(), "FooBar") + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + @defer.inlineCallbacks + def test_push_consumer(self): + string_file = BlockingStringWrite() + consumer = BackgroundFileConsumer(string_file, reactor=reactor) + + try: + producer = NonCallableMock(spec_set=[]) + + consumer.registerProducer(producer, True) + + consumer.write("Foo") + yield string_file.wait_for_n_writes(1) + + self.assertEqual(string_file.buffer, "Foo") + + consumer.write("Bar") + yield string_file.wait_for_n_writes(2) + + self.assertEqual(string_file.buffer, "FooBar") + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + @defer.inlineCallbacks + def test_push_producer_feedback(self): + string_file = BlockingStringWrite() + consumer = BackgroundFileConsumer(string_file, reactor=reactor) + + try: + producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) + + resume_deferred = defer.Deferred() + producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) + + consumer.registerProducer(producer, True) + + number_writes = 0 + with string_file.write_lock: + for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): + consumer.write("Foo") + number_writes += 1 + + producer.pauseProducing.assert_called_once() + + yield string_file.wait_for_n_writes(number_writes) + + yield resume_deferred + producer.resumeProducing.assert_called_once() + finally: + consumer.unregisterProducer() + + yield consumer.wait() + + self.assertTrue(string_file.closed) + + +class DummyPullProducer(object): + def __init__(self): + self.consumer = None + self.deferred = defer.Deferred() + + def resumeProducing(self): + d = self.deferred + self.deferred = defer.Deferred() + d.callback(None) + + def write_and_wait(self, bytes): + d = self.deferred + self.consumer.write(bytes) + return d + + def register_with_consumer(self, consumer): + d = self.deferred + self.consumer = consumer + self.consumer.registerProducer(self, False) + return d + + +class BlockingStringWrite(object): + def __init__(self): + self.buffer = "" + self.closed = False + self.write_lock = threading.Lock() + + self._notify_write_deferred = None + self._number_of_writes = 0 + + def write(self, bytes): + with self.write_lock: + self.buffer += bytes + self._number_of_writes += 1 + + reactor.callFromThread(self._notify_write) + + def close(self): + self.closed = True + + def _notify_write(self): + "Called by write to indicate a write happened" + with self.write_lock: + if not self._notify_write_deferred: + return + d = self._notify_write_deferred + self._notify_write_deferred = None + d.callback(None) + + @defer.inlineCallbacks + def wait_for_n_writes(self, n): + "Wait for n writes to have happened" + while True: + with self.write_lock: + if n <= self._number_of_writes: + return + + if not self._notify_write_deferred: + self._notify_write_deferred = defer.Deferred() + + d = self._notify_write_deferred + + yield d diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py deleted file mode 100644 index 9c795d9fdb..0000000000 --- a/tests/util/test_limiter.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from tests import unittest - -from twisted.internet import defer - -from synapse.util.async import Limiter - - -class LimiterTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def test_limiter(self): - limiter = Limiter(3) - - key = object() - - d1 = limiter.queue(key) - cm1 = yield d1 - - d2 = limiter.queue(key) - cm2 = yield d2 - - d3 = limiter.queue(key) - cm3 = yield d3 - - d4 = limiter.queue(key) - self.assertFalse(d4.called) - - d5 = limiter.queue(key) - self.assertFalse(d5.called) - - with cm1: - self.assertFalse(d4.called) - self.assertFalse(d5.called) - - self.assertTrue(d4.called) - self.assertFalse(d5.called) - - with cm3: - self.assertFalse(d5.called) - - self.assertTrue(d5.called) - - with cm2: - pass - - with (yield d4): - pass - - with (yield d5): - pass - - d6 = limiter.queue(key) - with (yield d6): - pass diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index afcba482f9..4729bd5a0a 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six.moves import range -from tests import unittest - -from twisted.internet import defer +from twisted.internet import defer, reactor +from twisted.internet.defer import CancelledError +from synapse.util import Clock, logcontext from synapse.util.async import Linearizer +from tests import unittest + class LinearizerTestCase(unittest.TestCase): @@ -38,7 +42,104 @@ class LinearizerTestCase(unittest.TestCase): with cm1: self.assertFalse(d2.called) + with (yield d2): + pass + + def test_lots_of_queued_things(self): + # we have one slow thing, and lots of fast things queued up behind it. + # it should *not* explode the stack. + linearizer = Linearizer() + + @defer.inlineCallbacks + def func(i, sleep=False): + with logcontext.LoggingContext("func(%s)" % i) as lc: + with (yield linearizer.queue("")): + self.assertEqual( + logcontext.LoggingContext.current_context(), lc) + if sleep: + yield Clock(reactor).sleep(0) + + self.assertEqual( + logcontext.LoggingContext.current_context(), lc) + + func(0, sleep=True) + for i in range(1, 100): + func(i) + + return func(1000) + + @defer.inlineCallbacks + def test_multiple_entries(self): + limiter = Linearizer(max_count=3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + cm4 = yield d4 + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + cm5 = yield d5 + + with cm2: + pass + + with cm4: + pass + + with cm5: + pass + + d6 = limiter.queue(key) + with (yield d6): + pass + + @defer.inlineCallbacks + def test_cancellation(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + d3 = linearizer.queue(key) + self.assertFalse(d3.called) + + d2.cancel() + + with cm1: + pass + self.assertTrue(d2.called) + try: + yield d2 + self.fail("Expected d2 to raise CancelledError") + except CancelledError: + pass - with (yield d2): + with (yield d3): pass diff --git a/tests/util/test_log_context.py b/tests/util/test_log_context.py deleted file mode 100644 index 9ffe209c4d..0000000000 --- a/tests/util/test_log_context.py +++ /dev/null @@ -1,96 +0,0 @@ -import twisted.python.failure -from twisted.internet import defer -from twisted.internet import reactor -from .. import unittest - -from synapse.util.async import sleep -from synapse.util import logcontext -from synapse.util.logcontext import LoggingContext - - -class LoggingContextTestCase(unittest.TestCase): - - def _check_test_key(self, value): - self.assertEquals( - LoggingContext.current_context().test_key, value - ) - - def test_with_context(self): - with LoggingContext() as context_one: - context_one.test_key = "test" - self._check_test_key("test") - - @defer.inlineCallbacks - def test_sleep(self): - @defer.inlineCallbacks - def competing_callback(): - with LoggingContext() as competing_context: - competing_context.test_key = "competing" - yield sleep(0) - self._check_test_key("competing") - - reactor.callLater(0, competing_callback) - - with LoggingContext() as context_one: - context_one.test_key = "one" - yield sleep(0) - self._check_test_key("one") - - def _test_preserve_fn(self, function): - sentinel_context = LoggingContext.current_context() - - callback_completed = [False] - - @defer.inlineCallbacks - def cb(): - context_one.test_key = "one" - yield function() - self._check_test_key("one") - - callback_completed[0] = True - - with LoggingContext() as context_one: - context_one.test_key = "one" - - # fire off function, but don't wait on it. - logcontext.preserve_fn(cb)() - - self._check_test_key("one") - - # now wait for the function under test to have run, and check that - # the logcontext is left in a sane state. - d2 = defer.Deferred() - - def check_logcontext(): - if not callback_completed[0]: - reactor.callLater(0.01, check_logcontext) - return - - # make sure that the context was reset before it got thrown back - # into the reactor - try: - self.assertIs(LoggingContext.current_context(), - sentinel_context) - d2.callback(None) - except BaseException: - d2.errback(twisted.python.failure.Failure()) - - reactor.callLater(0.01, check_logcontext) - - # test is done once d2 finishes - return d2 - - def test_preserve_fn_with_blocking_fn(self): - @defer.inlineCallbacks - def blocking_function(): - yield sleep(0) - - return self._test_preserve_fn(blocking_function) - - def test_preserve_fn_with_non_blocking_fn(self): - @defer.inlineCallbacks - def nonblocking_function(): - with logcontext.PreserveLoggingContext(): - yield defer.succeed(None) - - return self._test_preserve_fn(nonblocking_function) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py new file mode 100644 index 0000000000..c54001f7a4 --- /dev/null +++ b/tests/util/test_logcontext.py @@ -0,0 +1,179 @@ +import twisted.python.failure +from twisted.internet import defer, reactor + +from synapse.util import Clock, logcontext +from synapse.util.logcontext import LoggingContext + +from .. import unittest + + +class LoggingContextTestCase(unittest.TestCase): + + def _check_test_key(self, value): + self.assertEquals( + LoggingContext.current_context().request, value + ) + + def test_with_context(self): + with LoggingContext() as context_one: + context_one.request = "test" + self._check_test_key("test") + + @defer.inlineCallbacks + def test_sleep(self): + clock = Clock(reactor) + + @defer.inlineCallbacks + def competing_callback(): + with LoggingContext() as competing_context: + competing_context.request = "competing" + yield clock.sleep(0) + self._check_test_key("competing") + + reactor.callLater(0, competing_callback) + + with LoggingContext() as context_one: + context_one.request = "one" + yield clock.sleep(0) + self._check_test_key("one") + + def _test_run_in_background(self, function): + sentinel_context = LoggingContext.current_context() + + callback_completed = [False] + + def test(): + context_one.request = "one" + d = function() + + def cb(res): + self._check_test_key("one") + callback_completed[0] = True + return res + d.addCallback(cb) + + return d + + with LoggingContext() as context_one: + context_one.request = "one" + + # fire off function, but don't wait on it. + logcontext.run_in_background(test) + + self._check_test_key("one") + + # now wait for the function under test to have run, and check that + # the logcontext is left in a sane state. + d2 = defer.Deferred() + + def check_logcontext(): + if not callback_completed[0]: + reactor.callLater(0.01, check_logcontext) + return + + # make sure that the context was reset before it got thrown back + # into the reactor + try: + self.assertIs(LoggingContext.current_context(), + sentinel_context) + d2.callback(None) + except BaseException: + d2.errback(twisted.python.failure.Failure()) + + reactor.callLater(0.01, check_logcontext) + + # test is done once d2 finishes + return d2 + + def test_run_in_background_with_blocking_fn(self): + @defer.inlineCallbacks + def blocking_function(): + yield Clock(reactor).sleep(0) + + return self._test_run_in_background(blocking_function) + + def test_run_in_background_with_non_blocking_fn(self): + @defer.inlineCallbacks + def nonblocking_function(): + with logcontext.PreserveLoggingContext(): + yield defer.succeed(None) + + return self._test_run_in_background(nonblocking_function) + + def test_run_in_background_with_chained_deferred(self): + # a function which returns a deferred which looks like it has been + # called, but is actually paused + def testfunc(): + return logcontext.make_deferred_yieldable( + _chained_deferred_function() + ) + + return self._test_run_in_background(testfunc) + + @defer.inlineCallbacks + def test_make_deferred_yieldable(self): + # a function which retuns an incomplete deferred, but doesn't follow + # the synapse rules. + def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + return d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = logcontext.make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_chained_deferreds(self): + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = logcontext.make_deferred_yieldable(_chained_deferred_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + + @defer.inlineCallbacks + def test_make_deferred_yieldable_on_non_deferred(self): + """Check that make_deferred_yieldable does the right thing when its + argument isn't actually a deferred""" + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = logcontext.make_deferred_yieldable("bum") + self._check_test_key("one") + + r = yield d1 + self.assertEqual(r, "bum") + self._check_test_key("one") + + +# a function which returns a deferred which has been "called", but +# which had a function which returned another incomplete deferred on +# its callback list, so won't yet call any other new callbacks. +def _chained_deferred_function(): + d = defer.succeed(None) + + def cb(res): + d2 = defer.Deferred() + reactor.callLater(0, d2.callback, res) + return d2 + d.addCallback(cb) + return d diff --git a/tests/util/test_clock.py b/tests/util/test_logformatter.py index 9672603579..297aebbfbe 100644 --- a/tests/util/test_clock.py +++ b/tests/util/test_logformatter.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import util -from twisted.internet import defer +import sys + +from synapse.util.logformatter import LogFormatter + from tests import unittest -class ClockTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_time_bound_deferred(self): - # just a deferred which never resolves - slow_deferred = defer.Deferred() +class TestException(Exception): + pass + - clock = util.Clock() - time_bound = clock.time_bound_deferred(slow_deferred, 0.001) +class LogFormatterTestCase(unittest.TestCase): + def test_formatter(self): + formatter = LogFormatter() try: - yield time_bound - self.fail("Expected timedout error, but got nothing") - except util.DeferredTimedOutError: - pass + raise TestException("testytest") + except TestException: + ei = sys.exc_info() + + output = formatter.formatException(ei) + + # check the output looks vaguely sane + self.assertIn("testytest", output) + self.assertIn("Capture point", output) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index dfb78cb8bd..9b36ef4482 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,12 +14,12 @@ # limitations under the License. -from .. import unittest +from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from mock import Mock +from .. import unittest class LruCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 1d745ae1a7..24194e3b25 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,10 +14,10 @@ # limitations under the License. -from tests import unittest - from synapse.util.async import ReadWriteLock +from tests import unittest + class ReadWriteLockTestCase(unittest.TestCase): diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index d3a8630c2f..0f5b32fcc0 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -14,10 +14,11 @@ # limitations under the License. -from .. import unittest +from twisted.internet.defer import Deferred from synapse.util.caches.snapshot_cache import SnapshotCache -from twisted.internet.defer import Deferred + +from .. import unittest class SnapshotCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py new file mode 100644 index 0000000000..65b0f2e6fb --- /dev/null +++ b/tests/util/test_stream_change_cache.py @@ -0,0 +1,215 @@ +from mock import patch + +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from tests import unittest + + +class StreamChangeCacheTests(unittest.TestCase): + """ + Tests for StreamChangeCache. + """ + + def test_prefilled_cache(self): + """ + Providing a prefilled cache to StreamChangeCache will result in a cache + with the prefilled-cache entered in. + """ + cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) + self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) + + def test_has_entity_changed(self): + """ + StreamChangeCache.entity_has_changed will mark entities as changed, and + has_entity_changed will observe the changed entities. + """ + cache = StreamChangeCache("#test", 3) + + cache.entity_has_changed("user@foo.com", 6) + cache.entity_has_changed("bar@baz.net", 7) + + # If it's been changed after that stream position, return True + self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) + self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + + # If it's been changed at that stream position, return False + self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + + # If there's no changes after that stream position, return False + self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + + # If the entity does not exist, return False. + self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + + # If we request before the stream cache's earliest known position, + # return True, whether it's a known entity or not. + self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) + self.assertTrue(cache.has_entity_changed("not@here.website", 0)) + + @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) + def test_has_entity_changed_pops_off_start(self): + """ + StreamChangeCache.entity_has_changed will respect the max size and + purge the oldest items upon reaching that max size. + """ + cache = StreamChangeCache("#test", 1, max_size=2) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # The cache is at the max size, 2 + self.assertEqual(len(cache._cache), 2) + + # The oldest item has been popped off + self.assertTrue("user@foo.com" not in cache._entity_to_key) + + # If we update an existing entity, it keeps the two existing entities + cache.entity_has_changed("bar@baz.net", 5) + self.assertEqual( + set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + ) + + def test_get_all_entities_changed(self): + """ + StreamChangeCache.get_all_entities_changed will return all changed + entities since the given position. If the position is before the start + of the known stream, it returns None instead. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + self.assertEqual( + cache.get_all_entities_changed(1), + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], + ) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] + ) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) + self.assertEqual(cache.get_all_entities_changed(0), None) + + def test_has_any_entity_changed(self): + """ + StreamChangeCache.has_any_entity_changed will return True if any + entities have been changed since the provided stream position, and + False if they have not. If the cache has entries and the provided + stream position is before it, it will return True, otherwise False if + the cache has no entries. + """ + cache = StreamChangeCache("#test", 1) + + # With no entities, it returns False for the past, present, and future. + self.assertFalse(cache.has_any_entity_changed(0)) + self.assertFalse(cache.has_any_entity_changed(1)) + self.assertFalse(cache.has_any_entity_changed(2)) + + # We add an entity + cache.entity_has_changed("user@foo.com", 2) + + # With an entity, it returns True for the past, the stream start + # position, and False for the stream position the entity was changed + # on and ones after it. + self.assertTrue(cache.has_any_entity_changed(0)) + self.assertTrue(cache.has_any_entity_changed(1)) + self.assertFalse(cache.has_any_entity_changed(2)) + self.assertFalse(cache.has_any_entity_changed(3)) + + def test_get_entities_changed(self): + """ + StreamChangeCache.get_entities_changed will return the entities in the + given list that have changed since the provided stream ID. If the + stream position is earlier than the earliest known position, it will + return all of the entities queried for. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # Query all the entries, but mid-way through the stream. We should only + # get the ones after that point. + self.assertEqual( + cache.get_entities_changed( + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 + ), + set(["bar@baz.net", "user@elsewhere.org"]), + ) + + # Query all the entries mid-way through the stream, but include one + # that doesn't exist in it. We shouldn't get back the one that doesn't + # exist. + self.assertEqual( + cache.get_entities_changed( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ], + stream_pos=2, + ), + set(["bar@baz.net", "user@elsewhere.org"]), + ) + + # Query all the entries, but before the first known point. We will get + # all the entries we queried for, including ones that don't exist. + self.assertEqual( + cache.get_entities_changed( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ], + stream_pos=0, + ), + set( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ] + ), + ) + + # Query a subset of the entries mid-way through the stream. We should + # only get back the subset. + self.assertEqual( + cache.get_entities_changed( + [ + "bar@baz.net", + ], + stream_pos=2, + ), + set( + [ + "bar@baz.net", + ] + ), + ) + + def test_max_pos(self): + """ + StreamChangeCache.get_max_pos_of_last_change will return the most + recent point where the entity could have changed. If the entity is not + known, the stream start is provided instead. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # Known entities will return the point where they were changed. + self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2) + self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3) + self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4) + + # Unknown entities will return the stream start position. + self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 7ab578a185..a5f2261208 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -14,10 +14,10 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.treecache import TreeCache +from .. import unittest + class TreeCacheTestCase(unittest.TestCase): def test_get_set_onelevel(self): diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index c44567e52e..03201a4d9b 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .. import unittest - from synapse.util.wheel_timer import WheelTimer +from .. import unittest + class WheelTimerTestCase(unittest.TestCase): def test_single_insert_fetch(self): @@ -33,7 +33,7 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(156), [obj]) self.assertListEqual(wheel.fetch(170), []) - def test_mutli_insert(self): + def test_multi_insert(self): wheel = WheelTimer(bucket_size=5) obj1 = object() @@ -58,7 +58,7 @@ class WheelTimerTestCase(unittest.TestCase): wheel.insert(100, obj, 50) self.assertListEqual(wheel.fetch(120), [obj]) - def test_insert_past_mutli(self): + def test_insert_past_multi(self): wheel = WheelTimer(bucket_size=5) obj1 = object() diff --git a/tests/utils.py b/tests/utils.py index 4f7e32b3ab..c3dbff8507 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,35 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.server import HttpServer -from synapse.api.errors import cs_error, CodeMessageException, StoreError -from synapse.api.constants import EventTypes -from synapse.storage.prepare_database import prepare_database -from synapse.storage.engines import create_engine -from synapse.server import HomeServer -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter +import hashlib +from inspect import getcallargs -from synapse.util.logcontext import LoggingContext +from mock import Mock, patch +from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor -from twisted.enterprise.adbapi import ConnectionPool -from collections import namedtuple -from mock import patch, Mock -import hashlib -import urllib -import urlparse +from synapse.api.errors import CodeMessageException, cs_error +from synapse.federation.transport import server +from synapse.http.server import HttpServer +from synapse.server import HomeServer +from synapse.storage import PostgresEngine +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database +from synapse.util.logcontext import LoggingContext +from synapse.util.ratelimitutils import FederationRateLimiter -from inspect import getcallargs +# set this to True to run the tests against postgres instead of sqlite. +# It requires you to have a local postgres database called synapse_test, within +# which ALL TABLES WILL BE DROPPED +USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks -def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): +def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None, + **kargs): """Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. If no datastore is supplied a datastore backed by an in-memory sqlite db will be given to the HS. """ + if reactor is None: + from twisted.internet import reactor + if config is None: config = Mock() config.signing_key = [MockKey()] @@ -56,34 +61,84 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_replication_url = "" config.worker_app = None config.email_enable_notifs = False + config.block_non_admin_invites = False + config.federation_domain_whitelist = None + config.federation_rc_reject_limit = 10 + config.federation_rc_sleep_limit = 10 + config.federation_rc_sleep_delay = 100 + config.federation_rc_concurrent = 10 + config.filter_timeline_limit = 5000 + config.user_directory_search_all_users = False + config.user_consent_server_notice_content = None + config.block_events_without_consent_error = None + config.media_storage_providers = [] + config.auto_join_rooms = [] + + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + config.update_user_directory = False config.use_frozen_dicts = True - config.database_config = {"name": "sqlite3"} config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() + if USE_POSTGRES_FOR_TESTS: + config.database_config = { + "name": "psycopg2", + "args": { + "database": "synapse_test", + "cp_min": 1, + "cp_max": 5, + }, + } + else: + config.database_config = { + "name": "sqlite3", + "args": { + "database": ":memory:", + "cp_min": 1, + "cp_max": 1, + }, + } + + db_engine = create_engine(config.database_config) + + # we need to configure the connection pool to run the on_new_connection + # function, so that we can test code that uses custom sqlite functions + # (like rank). + config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection + if datastore is None: - db_pool = SQLiteMemoryDbPool() - yield db_pool.prepare() hs = HomeServer( - name, db_pool=db_pool, config=config, + name, config=config, + db_config=config.database_config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), - get_db_conn=db_pool.get_db_conn, + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), + reactor=reactor, **kargs ) + db_conn = hs.get_db_conn() + # make sure that the database is empty + if isinstance(db_engine, PostgresEngine): + cur = db_conn.cursor() + cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") + rows = cur.fetchall() + for r in rows: + cur.execute("DROP TABLE %s CASCADE" % r[0]) + yield prepare_database(db_conn, db_engine, config) hs.setup() else: hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine(config.database_config), + database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), + reactor=reactor, **kargs ) @@ -172,7 +227,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth: - headers["Authorization"] = ["X-Matrix origin=test,key=,sig="] + headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it @@ -183,7 +238,7 @@ class MockHttpResource(HttpServer): mock_request.args = urlparse.parse_qs(path.split('?')[1]) mock_request.path = path.split('?')[0] path = mock_request.path - except: + except Exception: pass for (method, pattern, func) in self.callbacks: @@ -194,7 +249,7 @@ class MockHttpResource(HttpServer): if matcher: try: args = [ - urllib.unquote(u).decode("UTF-8") + urlparse.unquote(u).decode("UTF-8") for u in matcher.groups() ] @@ -300,167 +355,6 @@ class MockClock(object): return d -class SQLiteMemoryDbPool(ConnectionPool, object): - def __init__(self): - super(SQLiteMemoryDbPool, self).__init__( - "sqlite3", ":memory:", - cp_min=1, - cp_max=1, - ) - - self.config = Mock() - self.config.database_config = {"name": "sqlite3"} - - def prepare(self): - engine = self.create_engine() - return self.runWithConnection( - lambda conn: prepare_database(conn, engine, self.config) - ) - - def get_db_conn(self): - conn = self.connect() - engine = self.create_engine() - prepare_database(conn, engine, self.config) - return conn - - def create_engine(self): - return create_engine(self.config.database_config) - - -class MemoryDataStore(object): - - Room = namedtuple( - "Room", - ["room_id", "is_public", "creator"] - ) - - def __init__(self): - self.tokens_to_users = {} - self.paths_to_content = {} - - self.members = {} - self.rooms = {} - - self.current_state = {} - self.events = [] - - class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")): - def fill_out_prev_events(self, event): - pass - - def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): - return self.Snapshot( - room_id, user_id, self.get_room_member(user_id, room_id) - ) - - def register(self, user_id, token, password_hash): - if user_id in self.tokens_to_users.values(): - raise StoreError(400, "User in use.") - self.tokens_to_users[token] = user_id - - def get_user_by_access_token(self, token): - try: - return { - "name": self.tokens_to_users[token], - } - except: - raise StoreError(400, "User does not exist.") - - def get_room(self, room_id): - try: - return self.rooms[room_id] - except: - return None - - def store_room(self, room_id, room_creator_user_id, is_public): - if room_id in self.rooms: - raise StoreError(409, "Conflicting room!") - - room = MemoryDataStore.Room( - room_id=room_id, - is_public=is_public, - creator=room_creator_user_id - ) - self.rooms[room_id] = room - - def get_room_member(self, user_id, room_id): - return self.members.get(room_id, {}).get(user_id) - - def get_room_members(self, room_id, membership=None): - if membership: - return [ - v for k, v in self.members.get(room_id, {}).items() - if v.membership == membership - ] - else: - return self.members.get(room_id, {}).values() - - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - return [ - m[user_id] for m in self.members.values() - if user_id in m and m[user_id].membership in membership_list - ] - - def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - limit=0, with_feedback=False): - return ([], from_key) # TODO - - def get_joined_hosts_for_room(self, room_id): - return defer.succeed([]) - - def persist_event(self, event): - if event.type == EventTypes.Member: - room_id = event.room_id - user = event.state_key - self.members.setdefault(room_id, {})[user] = event - - if hasattr(event, "state_key"): - key = (event.room_id, event.type, event.state_key) - self.current_state[key] = event - - self.events.append(event) - - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type: - key = (room_id, event_type, state_key) - if self.current_state.get(key): - return [self.current_state.get(key)] - return None - else: - return [ - e for e in self.current_state - if e[0] == room_id - ] - - def set_presence_state(self, user_localpart, state): - return defer.succeed({"state": 0}) - - def get_presence_list(self, user_localpart, accepted): - return [] - - def get_room_events_max_id(self): - return "s0" # TODO (erikj) - - def get_send_event_level(self, room_id): - return defer.succeed(0) - - def get_power_level(self, room_id, user_id): - return defer.succeed(0) - - def get_add_state_level(self, room_id): - return defer.succeed(0) - - def get_room_join_rule(self, room_id): - # TODO (erikj): This should be configurable - return defer.succeed("invite") - - def get_ops_levels(self, room_id): - return defer.succeed((5, 5, 5)) - - def insert_client_ip(self, user, access_token, ip, user_agent): - return defer.succeed(None) - - def _format_call(args, kwargs): return ", ".join( ["%r" % (a) for a in args] + @@ -498,7 +392,7 @@ class DeferredMockCallable(object): for _, _, d in self.expectations: try: d.errback(failure) - except: + except Exception: pass raise failure |